diff --git a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp index e9a02879051ae3..5c60cf6bfae5e7 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp @@ -11,6 +11,8 @@ #include "low_precision/kv_cache_concat.hpp" #include "low_precision/low_precision.hpp" #include "low_precision/move_fake_convert_up_through_kv_cache_concat.hpp" +#include "moe_transformations/device_routed_moe_transform.hpp" +#include "moe_transformations/gather_to_2d_gather.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/greater.hpp" #include "openvino/op/group_query_attention.hpp" @@ -1359,29 +1361,29 @@ bool is_moe_model(const std::shared_ptr& model) { return false; } -// Apply MoE-specific optimizations to stage configuration based on hint -void apply_moe_optimizations(ov::AnyMap& stage_config, - ::intel_npu::npuw::llm::MoEHint moe_hint, - const std::string& stage_name) { - // MoE expert and router pattern isolation options - const ov::AnyMap expert_opts = { - {"NPUW_ONLINE_PIPELINE", "REP"}, - {"NPUW_ONLINE_ISOLATE", "MOE"}, - {"NPUW_ONLINE_KEEP_BLOCK_SIZE", "4"}, - {"NPUW_UNFOLD_IREQS", "NO"}, - }; - +// Apply MoE-specific configuration based on hint +void apply_moe_config(ov::AnyMap& stage_config, + ::intel_npu::npuw::llm::MoEHint moe_hint, + const std::string& stage_name) { if (moe_hint == ::intel_npu::npuw::llm::MoEHint::HOST_ROUTED) { - LOG_INFO("MoE architecture optimization for " << stage_name - << " stage: HOST_ROUTED (host-side expert routing)"); + LOG_INFO("MoE config for " << stage_name << " stage: HOST_ROUTED (host-side expert routing)"); + // MoE expert and router pattern isolation options + const ov::AnyMap expert_opts = { + {"NPUW_ONLINE_PIPELINE", "REP"}, + {"NPUW_ONLINE_ISOLATE", "MOE"}, + {"NPUW_ONLINE_KEEP_BLOCK_SIZE", "4"}, + {"NPUW_UNFOLD_IREQS", "NO"}, + }; merge_config_with(stage_config, 
expert_opts); } else if (moe_hint == ::intel_npu::npuw::llm::MoEHint::DEVICE_ROUTED) { - NPUW_ASSERT(false && "MoE DEVICE_ROUTED is not yet implemented! " - "DEVICE_ROUTED will use in-graph gather-based expert selection to avoid " - "graph splitting and reduce host-device communication overhead. " - "This feature is planned for future releases."); + if (stage_name == "PREFILL") { + NPUW_ASSERT(false && "MoE DEVICE_ROUTED is not supported for PREFILL stage. " + "DEVICE_ROUTED mode uses in-graph gather-based expert selection which is only " + "optimized for GENERATE stage. Please use HOST_ROUTED or DENSE for PREFILL."); + } + stage_config["NPUW_UNFOLD_IREQS"] = "NO"; } else if (moe_hint == ::intel_npu::npuw::llm::MoEHint::DENSE) { - LOG_INFO("MoE architecture optimization for " << stage_name << " stage: DENSE (all experts active)"); + LOG_INFO("MoE config for " << stage_name << " stage: DENSE (all experts active)"); // DENSE mode requires CPU-only device due to extremely long NPU compilation time and high resource consumption auto npuw_devices = stage_config.count("NPUW_DEVICES") ? 
stage_config.at("NPUW_DEVICES").as() : "NPU"; @@ -1392,6 +1394,23 @@ void apply_moe_optimizations(ov::AnyMap& stage_config, } } +// Apply DEVICE_ROUTED MoE transformations to models +void apply_moe_device_routed_transforms(std::vector>& model_variants) { + LOG_INFO("Applying DEVICE_ROUTED MoE transformations..."); + ov::npuw::pass::DeviceRoutedMoETransform moe_transform; + ov::npuw::pass::GatherTo2DGather gather_transform; + + for (auto& model : model_variants) { + moe_transform.run_on_model(model); + LOG_DEBUG(" Applied DEVICE_ROUTED transformations to model variant"); + + // Apply Gather to 2D Gather transformation for HW optimization + gather_transform.run_on_model(model); + LOG_DEBUG(" Applied GatherTo2DGather transformation to model variant"); + } + LOG_INFO("DEVICE_ROUTED MoE transformations completed"); +} + } // namespace void ov::npuw::LLMCompiledModel::convert_stateful_lora_to_stateless(std::shared_ptr& model) { @@ -1601,6 +1620,18 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr& m LOG_INFO("Eagle3 speculative decoding mode enabled"); } + // Auto-detect MoE model by scanning for router/expert nodes + const bool is_moe = is_moe_model(model); + if (is_moe) { + // Only apply MoE defaults if not explicitly set in external config + if (npuw_llm_props.find("NPUW_LLM_SHARED_HEAD") == npuw_llm_props.end()) { + m_cfg.update({{"NPUW_LLM_SHARED_HEAD", "NO"}}); + } + if (npuw_llm_props.find("NPUW_LLM_GENERATE_HINT") == npuw_llm_props.end()) { + m_cfg.update({{"NPUW_LLM_GENERATE_HINT", "BEST_PERF"}}); + } + } + // NB: PREFILL_HINT is now compatible with the PREFILL_CONFIG section, unlike for // the generate model they're not mutually exclusive const ::intel_npu::npuw::llm::PrefillHint prefill_hint = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_HINT>(); @@ -1879,16 +1910,19 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr& m merge_config_with(generate_config, dyn_attn_opts); } - // Auto-detect MoE model by scanning for 
router/expert nodes - const bool is_moe = is_moe_model(kvcache_model); if (is_moe) { - // Apply MoE optimizations for prefill stage + // Apply MoE configuration for prefill stage const auto prefill_moe_hint = m_cfg.get<::intel_npu::NPUW_LLM_PREFILL_MOE_HINT>(); - apply_moe_optimizations(prefill_config, prefill_moe_hint, "PREFILL"); + apply_moe_config(prefill_config, prefill_moe_hint, "PREFILL"); - // Apply MoE optimizations for generate stage + // Apply MoE configuration for generate stage const auto generate_moe_hint = m_cfg.get<::intel_npu::NPUW_LLM_GENERATE_MOE_HINT>(); - apply_moe_optimizations(generate_config, generate_moe_hint, "GENERATE"); + apply_moe_config(generate_config, generate_moe_hint, "GENERATE"); + + // Apply model transformations only to GENERATE stage (PREFILL doesn't support DEVICE_ROUTED transformations) + if (generate_moe_hint == ::intel_npu::npuw::llm::MoEHint::DEVICE_ROUTED) { + apply_moe_device_routed_transforms(generate_model_variants); + } } // Note: with dynamic attention in EITHER STAGE, we have to diff --git a/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.cpp b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.cpp new file mode 100644 index 00000000000000..3f467f914d7d65 --- /dev/null +++ b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.cpp @@ -0,0 +1,564 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "device_routed_moe_transform.hpp" + +#include "../logging.hpp" +#include "../partitioning/patterns/moe.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/ops.hpp" + +namespace ov { +namespace npuw { +namespace pass { + +namespace opp = ov::pass::pattern; + +namespace { + +// ============================================================================ +// Helper structures for organizing transformation data +// 
============================================================================ + +struct LayerNodes { + std::vector> tiles; + std::vector> constant_reshapes; + std::vector> dynamic_reshapes; + std::vector> matmuls; + std::vector> adds; + std::vector> multiplies; + std::shared_ptr transpose; + size_t num_experts = 0; + + bool has_required_nodes() const { + return (!constant_reshapes.empty() || !dynamic_reshapes.empty()) && (!matmuls.empty() || !adds.empty()); + } +}; + +struct RouterInfo { + std::shared_ptr topk_node; + ov::Output topk_indices_raw; + ov::Output topk_softmax_scores; + int64_t k_value; + std::string layer_id; +}; + +// ============================================================================ +// Helper functions +// ============================================================================ + +// Trace back through Convert to find actual weight/bias source +inline ov::Output get_weight_source(const ov::Output& input) { + auto node = input.get_node_shared_ptr(); + if (auto convert = std::dynamic_pointer_cast(node)) { + return convert->input_value(0); + } + return input; +} + +// Extract layer ID from node name (e.g., "layers.0." from full name) +std::string extract_layer_id(const std::string& topk_name) { + size_t layers_pos = topk_name.find("layers."); + if (layers_pos == std::string::npos) { + return ""; + } + + size_t start = layers_pos; + size_t end = topk_name.find(".", start + 7); + if (end == std::string::npos) { + end = topk_name.find("/", start); + } + + return (end != std::string::npos) ? 
topk_name.substr(start, end - start + 1) : ""; +} + +// Check if node name belongs to the specified layer +inline bool belongs_to_layer(const std::string& node_name, const std::string& layer_id) { + return node_name.find(layer_id) != std::string::npos; +} + +// Check if a reshape operation is unsqueeze-like (only inserts dimensions with size 1) +bool is_unsqueeze_like_reshape(const std::shared_ptr& reshape) { + auto input_shape = reshape->input_value(0).get_partial_shape(); + auto output_shape = reshape->get_output_partial_shape(0); + + // Must have static ranks and output rank must be greater than input rank + if (!input_shape.rank().is_static() || !output_shape.rank().is_static() || + output_shape.rank().get_length() <= input_shape.rank().get_length()) { + return false; + } + + // Verify all output dims are either from input or newly inserted with size 1 + int64_t in_idx = 0; + for (int64_t out_idx = 0; out_idx < output_shape.rank().get_length(); ++out_idx) { + if (in_idx < input_shape.rank().get_length() && + (!input_shape[in_idx].is_static() || !output_shape[out_idx].is_static() || + input_shape[in_idx].get_length() == output_shape[out_idx].get_length())) { + // This dimension matches input dimension + ++in_idx; + } else if (!output_shape[out_idx].is_static() || output_shape[out_idx].get_length() != 1) { + // Not a size-1 inserted dimension - not unsqueeze-like + return false; + } + // else: this is a newly inserted dimension with size 1, continue + } + + // Valid unsqueeze if all input dimensions were matched + return in_idx == input_shape.rank().get_length(); +} + +// ============================================================================ +// Router processing +// ============================================================================ + +std::optional process_router_topk(const std::shared_ptr& topk_node) { + if (!topk_node || topk_node->get_mode() != ov::op::v11::TopK::Mode::MAX) { + return std::nullopt; + } + + std::string topk_name = 
topk_node->get_friendly_name(); + if (topk_name.find(ov::npuw::patterns::moe::MLP_ROUTER_NAME) == std::string::npos) { + return std::nullopt; + } + + // Validate TopK indices shape (batch dimension should be 1, indicates it is model for decoding) + auto topk_indices_raw = topk_node->output(1); + auto indices_shape = topk_indices_raw.get_partial_shape(); + if (indices_shape.rank().is_static() && indices_shape.rank().get_length() == 2) { + if (indices_shape[0].is_static() && indices_shape[0].get_length() != 1) { + LOG_WARN(" TopK indices batch dimension is not 1, skipping"); + return std::nullopt; + } + } + + // Extract K value + auto k_input = topk_node->input_value(1); + auto k_const = std::dynamic_pointer_cast(k_input.get_node_shared_ptr()); + if (!k_const) { + LOG_WARN(" TopK K value is not a constant, skipping"); + return std::nullopt; + } + int64_t k_value = k_const->cast_vector()[0]; + + // Extract layer ID + std::string layer_id = extract_layer_id(topk_name); + if (layer_id.empty()) { + LOG_WARN(" Cannot extract layer ID from: " << topk_name); + return std::nullopt; + } + + // Find Softmax for router scores + auto topk_values = topk_node->output(0); + std::shared_ptr topk_softmax = nullptr; + for (const auto& target : topk_values.get_target_inputs()) { + auto consumer = target.get_node()->shared_from_this(); + if (auto softmax = std::dynamic_pointer_cast(consumer)) { + topk_softmax = softmax; + break; + } + } + + if (!topk_softmax) { + LOG_WARN(" No Softmax found for TopK values"); + return std::nullopt; + } + + LOG_INFO("DeviceRoutedMoE: Processing router TopK: " << topk_name << " (K=" << k_value << ")"); + + return RouterInfo{topk_node, topk_indices_raw, topk_softmax->output(0), k_value, layer_id}; +} + +// ============================================================================ +// Node collection per layer +// ============================================================================ + +LayerNodes collect_layer_nodes(const std::shared_ptr& model, 
const RouterInfo& router) { + LayerNodes nodes; + const std::string& layer_id = router.layer_id; + int64_t k_value = router.k_value; + + // Single pass through all nodes to collect relevant operations + for (const auto& n : model->get_ordered_ops()) { + std::string node_name = n->get_friendly_name(); + + // Skip nodes not belonging to this layer or not MoE expert nodes + if (node_name.find(ov::npuw::patterns::moe::MLP_EXPERT_NAME) == std::string::npos || + !belongs_to_layer(node_name, layer_id)) { + continue; + } + + // Collect Tile nodes + if (auto tile = std::dynamic_pointer_cast(n)) { + auto repeats_const = + std::dynamic_pointer_cast(tile->input_value(1).get_node_shared_ptr()); + if (repeats_const) { + auto repeats_data = repeats_const->cast_vector(); + if (!repeats_data.empty() && repeats_data[0] > k_value) { + if (nodes.num_experts == 0) { + nodes.num_experts = static_cast(repeats_data[0]); + } + nodes.tiles.push_back(tile); + } + } + continue; + } + + // Collect Reshape nodes + if (auto reshape = std::dynamic_pointer_cast(n)) { + auto shape_const = + std::dynamic_pointer_cast(reshape->input_value(1).get_node_shared_ptr()); + + if (!shape_const) { + // Dynamic reshape - check if it's unsqueeze-like + if (is_unsqueeze_like_reshape(reshape)) { + nodes.dynamic_reshapes.push_back(reshape); + } + } else { + // Constant reshape - check if dim 0 is expert dimension + auto shape_data = shape_const->cast_vector(); + if (nodes.num_experts > 0 && !shape_data.empty() && + shape_data[0] == static_cast(nodes.num_experts)) { + nodes.constant_reshapes.push_back(reshape); + } + } + continue; + } + + // Collect MatMul nodes + if (auto matmul = std::dynamic_pointer_cast(n)) { + auto weight_source = get_weight_source(matmul->input_value(1)); + auto weight_node = weight_source.get_node_shared_ptr(); + + // Check if quantized weight comes from Multiply with expert-dimension constant + if (auto multiply = std::dynamic_pointer_cast(weight_node)) { + for (size_t i = 0; i < 2; ++i) { 
+ auto mul_input = multiply->get_input_node_shared_ptr(i); + if (auto const_node = std::dynamic_pointer_cast(mul_input)) { + auto shape = const_node->get_shape(); + if (nodes.num_experts > 0 && shape.size() >= 2 && shape[0] == nodes.num_experts) { + nodes.matmuls.push_back(matmul); + break; + } + } + } + } + continue; + } + + // Collect Add nodes + if (auto add = std::dynamic_pointer_cast(n)) { + // Check both inputs for expert-dimension bias constant + for (size_t input_idx = 0; input_idx < 2; ++input_idx) { + auto bias_source = get_weight_source(add->input_value(input_idx)); + auto const_node = std::dynamic_pointer_cast(bias_source.get_node_shared_ptr()); + + if (const_node) { + auto shape = const_node->get_shape(); + if (nodes.num_experts > 0 && shape.size() >= 1 && shape[0] == nodes.num_experts) { + nodes.adds.push_back(add); + break; + } + } + } + continue; + } + + // Collect Multiply nodes, e.g. AWQ multiply (one input from constant, other input is not constant/convert, and + // user is not MatMul) + if (auto multiply = std::dynamic_pointer_cast(n)) { + // Skip if this Multiply is used by MatMul (it's MatMul weights multiply, which has been processed by + // transform_matmuls) + bool used_by_matmul = false; + for (const auto& output : multiply->outputs()) { + for (const auto& target : output.get_target_inputs()) { + auto user = target.get_node()->shared_from_this(); + if (std::dynamic_pointer_cast(user)) { + used_by_matmul = true; + break; + } + } + if (used_by_matmul) + break; + } + if (used_by_matmul) { + continue; + } + + // Check both inputs: one should be constant with expert dimension, other should not be constant/convert + for (size_t const_idx = 0; const_idx < 2; ++const_idx) { + size_t other_idx = 1 - const_idx; + + auto const_source = get_weight_source(multiply->input_value(const_idx)); + auto const_node = std::dynamic_pointer_cast(const_source.get_node_shared_ptr()); + + if (const_node) { + auto shape = const_node->get_shape(); + if 
(nodes.num_experts > 0 && shape.size() >= 1 && shape[0] == nodes.num_experts) { + // Check if other input is not constant/convert + auto other_input = multiply->input_value(other_idx); + auto other_source = get_weight_source(other_input); + auto other_node = other_source.get_node_shared_ptr(); + + if (!std::dynamic_pointer_cast(other_node)) { + nodes.multiplies.push_back(multiply); + break; + } + } + } + } + continue; + } + + // Collect Transpose node + if (auto transpose = std::dynamic_pointer_cast(n)) { + auto input_node = transpose->input_value(0).get_node_shared_ptr(); + if (std::dynamic_pointer_cast(input_node) || + std::dynamic_pointer_cast(input_node)) { + nodes.transpose = transpose; + // Don't break - continue collecting other nodes + } + continue; + } + } + + return nodes; +} + +// ============================================================================ +// Node transformation +// ============================================================================ + +void transform_tiles(LayerNodes& nodes, int64_t k_value) { + for (auto& tile : nodes.tiles) { + auto repeats_const = + std::dynamic_pointer_cast(tile->input_value(1).get_node_shared_ptr()); + auto repeats_data = repeats_const->cast_vector(); + repeats_data[0] = k_value; + + auto new_repeats = + ov::op::v0::Constant::create(repeats_const->get_element_type(), repeats_const->get_shape(), repeats_data); + + tile->input(1).replace_source_output(new_repeats); + ov::copy_runtime_info(repeats_const, new_repeats); + } +} + +void transform_constant_reshapes(LayerNodes& nodes, int64_t k_value) { + for (auto& reshape : nodes.constant_reshapes) { + auto shape_const = + std::dynamic_pointer_cast(reshape->input_value(1).get_node_shared_ptr()); + auto shape_data = shape_const->cast_vector(); + + // Replace dim 0 (expert dimension) with K + shape_data[0] = k_value; + + auto new_shape = + ov::op::v0::Constant::create(shape_const->get_element_type(), shape_const->get_shape(), shape_data); + + 
reshape->input(1).replace_source_output(new_shape); + ov::copy_runtime_info(shape_const, new_shape); + } +} + +void transform_dynamic_reshapes(LayerNodes& nodes) { + // Replace dynamic reshapes with Unsqueeze at dimension 1 + for (auto& reshape : nodes.dynamic_reshapes) { + auto data_input = reshape->input_value(0); + auto unsqueeze_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto unsqueeze = std::make_shared(data_input, unsqueeze_axis); + unsqueeze->set_friendly_name(reshape->get_friendly_name() + "/unsqueeze_dim1"); + + ov::replace_node(reshape, unsqueeze); + ov::copy_runtime_info(reshape, unsqueeze); + } +} + +void transform_matmuls(LayerNodes& nodes, const std::shared_ptr& topk_indices) { + for (auto& matmul : nodes.matmuls) { + auto weight_input = matmul->input_value(1); + auto weight_source = get_weight_source(weight_input); + auto weight_node = weight_source.get_node_shared_ptr(); + + bool transformed = false; + // Handle Multiply case: insert Gather before Multiply on expert-dimension inputs + if (auto multiply = std::dynamic_pointer_cast(weight_node)) { + for (size_t i = 0; i < 2; ++i) { + auto mul_input = multiply->input_value(i); + auto mul_source = get_weight_source(mul_input); + auto mul_source_node = mul_source.get_node_shared_ptr(); + + if (auto const_node = std::dynamic_pointer_cast(mul_source_node)) { + auto shape = const_node->get_shape(); + if (shape.size() >= 2 && shape[0] == nodes.num_experts) { + // Insert Gather on constant source + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}); + auto gathered = std::make_shared(mul_source, topk_indices, gather_axis); + gathered->set_friendly_name(mul_source_node->get_friendly_name() + "/gathered"); + + // Recreate Convert if present + auto mul_input_node = mul_input.get_node_shared_ptr(); + if (auto convert = std::dynamic_pointer_cast(mul_input_node)) { + auto new_convert = + std::make_shared(gathered, convert->get_destination_type()); + 
new_convert->set_friendly_name(convert->get_friendly_name() + "/regathered"); + multiply->input(i).replace_source_output(new_convert); + ov::copy_runtime_info({mul_source_node, convert}, {gathered, new_convert}); + } else { + multiply->input(i).replace_source_output(gathered); + ov::copy_runtime_info(mul_source_node, gathered); + } + transformed = true; + } + } + } + } else { + // Direct weight case: insert Gather on weight input + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}); + auto gathered = std::make_shared(weight_input, topk_indices, gather_axis); + gathered->set_friendly_name(weight_input.get_node()->get_friendly_name() + "/gathered"); + + matmul->input(1).replace_source_output(gathered); + ov::copy_runtime_info(weight_input.get_node_shared_ptr(), gathered); + transformed = true; + } + OPENVINO_ASSERT(transformed, "Failed to transform MatMul weights for node: ", matmul->get_friendly_name()); + } +} + +void transform_adds(LayerNodes& nodes, const std::shared_ptr& topk_indices) { + for (auto& add : nodes.adds) { + bool transformed = false; + for (size_t input_idx = 0; input_idx < 2; ++input_idx) { + auto bias_input = add->input_value(input_idx); + auto bias_source = get_weight_source(bias_input); + auto const_node = std::dynamic_pointer_cast(bias_source.get_node_shared_ptr()); + + if (const_node) { + auto shape = const_node->get_shape(); + if (nodes.num_experts > 0 && shape.size() >= 1 && shape[0] == nodes.num_experts) { + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}); + auto gathered = std::make_shared(bias_input, topk_indices, gather_axis); + gathered->set_friendly_name(bias_input.get_node()->get_friendly_name() + "/gathered"); + + add->input(input_idx).replace_source_output(gathered); + ov::copy_runtime_info(bias_input.get_node_shared_ptr(), gathered); + + transformed = true; + break; + } + } + } + OPENVINO_ASSERT(transformed, "Failed to transform Add biases for node: ", 
add->get_friendly_name()); + } +} + +void transform_multiplies(LayerNodes& nodes, const std::shared_ptr& topk_indices) { + for (auto& multiply : nodes.multiplies) { + bool transformed = false; + for (size_t input_idx = 0; input_idx < 2; ++input_idx) { + auto const_input = multiply->input_value(input_idx); + auto const_source = get_weight_source(const_input); + auto const_node = std::dynamic_pointer_cast(const_source.get_node_shared_ptr()); + + if (const_node) { + auto shape = const_node->get_shape(); + if (nodes.num_experts > 0 && shape.size() >= 1 && shape[0] == nodes.num_experts) { + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {0}); + auto gathered = std::make_shared(const_input, topk_indices, gather_axis); + gathered->set_friendly_name(const_input.get_node()->get_friendly_name() + "/gathered"); + + multiply->input(input_idx).replace_source_output(gathered); + ov::copy_runtime_info(const_input.get_node_shared_ptr(), gathered); + + transformed = true; + break; + } + } + } + OPENVINO_ASSERT(transformed, + "Failed to transform Multiply constants for node: ", + multiply->get_friendly_name()); + } +} + +void transform_transpose(LayerNodes& nodes, const ov::Output& topk_softmax_scores) { + if (nodes.transpose) { + auto transpose_input = nodes.transpose->input_value(0); + auto input_node = transpose_input.get_node_shared_ptr(); + + nodes.transpose->input(0).replace_source_output(topk_softmax_scores); + ov::copy_runtime_info(input_node, topk_softmax_scores.get_node_shared_ptr()); + } +} + +bool apply_layer_transformation(const RouterInfo& router, LayerNodes& nodes) { + // Validate we have required nodes + if (!nodes.has_required_nodes()) { + if (nodes.constant_reshapes.empty() && nodes.dynamic_reshapes.empty()) { + LOG_WARN(" Skipping layer " << router.layer_id << ": No Reshape nodes found"); + } else { + LOG_WARN(" Skipping layer " << router.layer_id << ": No MatMul/Add nodes found"); + } + return false; + } + + // Create reshaped TopK 
indices [1, K] -> [K] + auto new_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {router.k_value}); + auto topk_indices = std::make_shared(router.topk_indices_raw, new_shape, false); + topk_indices->set_friendly_name(router.topk_node->get_friendly_name() + "/indices_reshaped"); + + // Apply all transformations + transform_tiles(nodes, router.k_value); + transform_constant_reshapes(nodes, router.k_value); + transform_dynamic_reshapes(nodes); + transform_matmuls(nodes, topk_indices); + transform_adds(nodes, topk_indices); + transform_multiplies(nodes, topk_indices); + transform_transpose(nodes, router.topk_softmax_scores); + + LOG_INFO("DeviceRoutedMoE transformation successful for " << router.layer_id); + LOG_INFO(" Tiles: " << nodes.tiles.size() << ", ConstReshapes: " << nodes.constant_reshapes.size() + << ", DynReshapes: " << nodes.dynamic_reshapes.size() << ", MatMuls: " << nodes.matmuls.size() + << ", Adds: " << nodes.adds.size() << ", Multiplies: " << nodes.multiplies.size() + << ", K=" << router.k_value); + + return true; +} + +} // anonymous namespace + +// ============================================================================ +// Main transformation entry point +// ============================================================================ + +bool DeviceRoutedMoETransform::run_on_model(const std::shared_ptr& model) { + LOG_DEBUG("DeviceRoutedMoETransform: Starting transformation"); + + bool model_changed = false; + + // Process each Router TopK node (one per MoE layer) + for (const auto& node : model->get_ordered_ops()) { + auto topk_node = std::dynamic_pointer_cast(node); + + // Step 1: Process and validate router + auto router = process_router_topk(topk_node); + if (!router.has_value()) { + continue; + } + + // Step 2: Collect all nodes for this layer + auto layer_nodes = collect_layer_nodes(model, router.value()); + + // Step 3: Transform collected nodes (all-or-nothing) + if (apply_layer_transformation(router.value(), 
layer_nodes)) { + model_changed = true; + } + } + + return model_changed; +} + +} // namespace pass +} // namespace npuw +} // namespace ov diff --git a/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.hpp b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.hpp new file mode 100644 index 00000000000000..11a7eda9980ff8 --- /dev/null +++ b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.hpp @@ -0,0 +1,91 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +/** + * @file device_routed_moe_transform.hpp + * @brief Device-routed MoE transformation using Gather-based expert selection + * + * This transformation implements DEVICE_ROUTED mode for MoE models, where expert + * selection is performed dynamically on the device using Gather operations driven + * by Router's TopK outputs, avoiding graph splitting and reducing host-device overhead. + * + * Key Features: + * - Uses TopK indices from router to dynamically gather expert weights and biases + * - No NonZero/ScatterElementsUpdate (device-friendly operations) + * - Keeps full computation in-graph without splitting + * - Reduces host-device communication compared to HOST_ROUTED mode + * + * Transformation Strategy (Two-Phase Approach): + * Phase 1 - Collection: + * 1. Locate Router's TopK node (selecting K active experts per token) + * 2. Extract TopK indices and Softmax scores + * 3. Collect all expert nodes for the layer: + * - Tile nodes (expert dimension expansion) + * - Reshape nodes (constant or dynamic/unsqueeze-like) + * - MatMul nodes (expert computation with grouped weights) + * - Add nodes (expert biases) + * - Transpose nodes (routing score processing) + * + * Phase 2 - Transformation (all-or-nothing): + * 1. Update Tile repeat counts from num_experts to K + * 2. Update Reshape shapes to use K instead of num_experts + * 3. 
Replace dynamic reshapes with Unsqueeze operations + * 4. Insert Gather on expert weights/scales (for MatMul inputs) + * 5. Insert Gather on expert biases (for Add inputs) + * 6. Replace routing scores with TopK Softmax outputs + * + * Quantization Support: + * - Detects Multiply nodes in weight path (quantized_weight * scale) + * - Inserts Gather on both quantized weights and per-expert scales + * - Preserves Convert nodes for data type handling + */ + +#pragma once + +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace npuw { +namespace pass { + +/** + * @brief Transform batched MoE experts to use Gather-based dynamic expert selection + * + * Pattern to match: + * Router: + * Input → MatMul(router_weights) → Add(router_bias) → TopK(K=num_active_experts) + * TopK.output(0): top-K scores → Softmax → routing weights + * TopK.output(1): top-K indices → used for Gather operations + * + * Experts (batched execution for all num_experts): + * Tile(repeat=[num_experts, 1, ...]) → Reshape([num_experts, ...]) + * → MatMul(grouped_weights[num_experts, d1, d2]) → Add(grouped_bias[num_experts, d]) + * → ... expert computation ... + * → Multiply(expert_outputs × routing_scores) → ReduceSum + * + * Transformation: + * Router: + * TopK.output(1) → Reshape([K]) → used as Gather indices + * TopK.output(0) → Softmax → replaces ScatterElementsUpdate routing scores + * + * Experts (dynamic execution for K active experts): + * - Tile(repeat=[K, 1, ...]) // reduced from num_experts to K + * - Reshape([K, ...]) // updated shape + * - Gather(grouped_weights, topk_indices, axis=0) → [K, d1, d2] + * - Gather(grouped_bias, topk_indices, axis=0) → [K, d] + * - For quantized weights: Gather both weight and scale tensors + * - Dynamic reshapes replaced with Unsqueeze operations + * + * This transformation reduces computation from num_experts (e.g., 32) to K (e.g., 4) + * active experts per token, with expert selection performed on-device via Gather. 
+ */ +class DeviceRoutedMoETransform : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("npuw::pass::DeviceRoutedMoETransform", "0"); + bool run_on_model(const std::shared_ptr& model) override; +}; + +} // namespace pass +} // namespace npuw +} // namespace ov diff --git a/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.cpp b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.cpp new file mode 100644 index 00000000000000..e920186e9c5793 --- /dev/null +++ b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.cpp @@ -0,0 +1,203 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gather_to_2d_gather.hpp" + +#include +#include + +#include "../logging.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/ops.hpp" + +namespace ov { +namespace npuw { +namespace pass { + +namespace { + +// ============================================================================ +// Helper structures for organizing transformation data +// ============================================================================ + +struct GatherInfo { + std::shared_ptr gather_node; + int64_t N; // num_experts (data dim 0) + int64_t M; // feature_dim (data dim 1) + int64_t K; // hidden_dim (data dim 2) + int64_t I; // num_selected (indices size) +}; + +// ============================================================================ +// Helper functions +// ============================================================================ + +// Check if a Gather node is valid for 3D->2D transformation +std::optional validate_gather_for_transform(const std::shared_ptr& gather) { + if (!gather) { + return std::nullopt; + } + + // Get gather inputs + auto data_input = gather->input_value(0); + auto indices_input = gather->input_value(1); + auto axis_input = gather->input_value(2); + + // Check if axis is 0 (gathering on first dimension) + auto axis_const = 
std::dynamic_pointer_cast(axis_input.get_node_shared_ptr()); + if (!axis_const) { + return std::nullopt; + } + auto axis_value = axis_const->cast_vector()[0]; + if (axis_value != 0) { + return std::nullopt; + } + + // Check data shape: should be 3D [N, M, K] with static dimensions + auto data_shape = data_input.get_partial_shape(); + if (!data_shape.rank().is_static() || data_shape.rank().get_length() != 3) { + return std::nullopt; + } + if (!data_shape[0].is_static() || !data_shape[1].is_static() || !data_shape[2].is_static()) { + return std::nullopt; + } + + int64_t M = data_shape[1].get_length(); + int64_t K = data_shape[2].get_length(); + + // Only transform if both M and K are not 1 (otherwise transformation is not beneficial) + if (M == 1 || K == 1) { + return std::nullopt; + } + + // Check indices shape: should be 1D [I] with static dimension + auto indices_shape = indices_input.get_partial_shape(); + if (!indices_shape.rank().is_static() || indices_shape.rank().get_length() != 1) { + return std::nullopt; + } + if (!indices_shape[0].is_static()) { + return std::nullopt; + } + + // Valid gather - return info + return GatherInfo{gather, data_shape[0].get_length(), M, K, indices_shape[0].get_length()}; +} + +// Transform a single 3D Gather to 2D Gather sequence +void transform_gather_to_2d(const GatherInfo& info) { + auto gather = info.gather_node; + auto data_input = gather->input_value(0); + auto indices_input = gather->input_value(1); + + std::string gather_name = gather->get_friendly_name(); + + // Step 1: Reshape indices [I] -> [I, 1] + std::vector indices_reshape_data = {static_cast(info.I), 1}; + auto indices_reshape_shape = + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, indices_reshape_data.data()); + auto reshaped_indices = std::make_shared(indices_input, indices_reshape_shape, false); + reshaped_indices->set_friendly_name(gather_name + "/indices_reshaped"); + + // Step 2: Multiply by M to get expert starting positions [I, 1] + 
std::vector m_data = {static_cast(info.M)}; + auto m_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1, 1}, m_data.data()); + auto experts_start = std::make_shared(reshaped_indices, m_const); + experts_start->set_friendly_name(gather_name + "/experts_start"); + + // Step 3: Create range [0, 1, 2, ..., M-1] and tile to [I, M] + std::vector range_values(info.M); + std::iota(range_values.begin(), range_values.end(), 0); + auto range_m = + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1, static_cast(info.M)}, range_values); + // Mark this constant to be preserved in function body during partitioning + range_m->get_rt_info()["npuw_moe_gather_indices"] = true; + + // Tile range to [I, M] + std::vector tile_repeats_data = {static_cast(info.I), 1}; + auto tile_repeats = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, tile_repeats_data.data()); + auto range_m_tiled = std::make_shared(range_m, tile_repeats); + range_m_tiled->set_friendly_name(gather_name + "/range_tiled"); + + // Step 4: Add experts_start + range to get final indices [I, M] + auto new_indices = std::make_shared(experts_start, range_m_tiled); + new_indices->set_friendly_name(gather_name + "/new_indices"); + + // Step 5: Flatten indices [I, M] -> [I*M] + std::vector flat_indices_shape_data = {static_cast(info.I * info.M)}; + auto flat_indices_shape = + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, flat_indices_shape_data.data()); + auto flat_indices = std::make_shared(new_indices, flat_indices_shape, false); + flat_indices->set_friendly_name(gather_name + "/flat_indices"); + + // Step 6: Flatten weights [N, M, K] -> [N*M, K] + std::vector flat_weights_shape_data = {static_cast(info.N * info.M), + static_cast(info.K)}; + auto flat_weights_shape = + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, flat_weights_shape_data.data()); + auto flat_weights = std::make_shared(data_input, flat_weights_shape, false); + 
flat_weights->set_friendly_name(gather_name + "/flat_weights"); + + // Step 7: Perform 2D Gather [I*M, K] + std::vector gather_axis_data = {0}; + auto gather_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, gather_axis_data.data()); + auto gathered_flat = std::make_shared(flat_weights, flat_indices, gather_axis); + gathered_flat->set_friendly_name(gather_name + "/gathered_flat"); + + // Step 8: Reshape to final output [I, M, K] + std::vector output_shape_data = {static_cast(info.I), + static_cast(info.M), + static_cast(info.K)}; + auto output_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, output_shape_data.data()); + auto final_output = std::make_shared(gathered_flat, output_shape, false); + final_output->set_friendly_name(gather_name + "/output"); + + // Replace the original Gather with the final Reshape + ov::replace_node(gather, final_output); + ov::copy_runtime_info(gather, + {reshaped_indices, + experts_start, + range_m_tiled, + new_indices, + flat_indices, + flat_weights, + gathered_flat, + final_output}); +} + +} // anonymous namespace + +// ============================================================================ +// Main transformation entry point +// ============================================================================ + +bool GatherTo2DGather::run_on_model(const std::shared_ptr& model) { + LOG_DEBUG("GatherTo2DGather: Starting transformation"); + + std::vector gathers_to_transform; + + // Collect and validate Gather nodes + for (const auto& node : model->get_ordered_ops()) { + auto gather = std::dynamic_pointer_cast(node); + auto gather_info = validate_gather_for_transform(gather); + + if (gather_info.has_value()) { + gathers_to_transform.push_back(gather_info.value()); + } + } + + // Transform each valid Gather + for (const auto& info : gathers_to_transform) { + transform_gather_to_2d(info); + } + + if (!gathers_to_transform.empty()) { + LOG_INFO("GatherTo2DGather: Transformed " << 
gathers_to_transform.size() << " Gather nodes"); + } + + return !gathers_to_transform.empty(); +} + +} // namespace pass +} // namespace npuw +} // namespace ov diff --git a/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.hpp b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.hpp new file mode 100644 index 00000000000000..1ae1c2b9a2da10 --- /dev/null +++ b/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace npuw { +namespace pass { + +/** + * @brief Transform 3D Gather to 2D Gather sequence for hardware optimization + * + * This pass transforms: + * Gather(weights[N, M, K], indices[I]) -> output[I, M, K] + * + * Into equivalent sequence: + * 1. reshaped_indices = Reshape(indices[I]) -> [I, 1] + * 2. experts_start = Multiply(reshaped_indices, M) -> [I, 1] + * 3. range_m = Constant([0, 1, ..., M-1]) -> [1, M] + * range_m_tiled = Tile(range_m, [I, 1]) -> [I, M] + * 4. new_indices = Add(experts_start, range_m_tiled) -> [I, M] + * 5. flat_indices = Reshape(new_indices) -> [I*M] + * 6. flat_weights = Reshape(weights) -> [N*M, K] + * 7. gathered_flat = Gather(flat_weights, flat_indices, axis=0) -> [I*M, K] + * 8. output = Reshape(gathered_flat) -> [I, M, K] + * + * This transformation enables better hardware support for Gather operations + * by converting multi-dimensional gather into flattened 2D gather. 
+ */ +class GatherTo2DGather : public ov::pass::ModelPass { +public: + OPENVINO_RTTI("GatherTo2DGather", "0"); + bool run_on_model(const std::shared_ptr& model) override; +}; + +} // namespace pass +} // namespace npuw +} // namespace ov diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp index 3af791d2d46ef1..854e292b5ca1e8 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/partitioning.cpp @@ -1411,6 +1411,12 @@ void Partitioner::saveRepeatedConstants(const std::string& func_name) { } return false; }; + // Helper to check if a constant is MoE Gather indices (marked by GatherTo2DGather pass) + auto is_moe_gather_const = [](const CTPtr& const_node) -> bool { + const auto& rt_info = const_node->get_rt_info(); + return rt_info.count("npuw_moe_gather_indices") > 0; + }; + auto check_and_mark = [&](const ov::npuw::RepeatedBlock::MatchedLayers& bank) { std::unordered_set instances; for (auto&& l : bank) { @@ -1422,19 +1428,30 @@ void Partitioner::saveRepeatedConstants(const std::string& func_name) { LOG_DEBUG("Checking a bank with prototype node " << proto_node << "..."); LOG_BLOCK(); - if (ov::npuw::partitioning::traits::is_tiny_shape(proto_shape) && - std::all_of(instances.begin(), instances.end(), [&](const CTPtr& other_node) -> bool { - return (other_node->output(0).get_shape() == proto_node->output(0).get_shape()) && - values_are_the_same(proto_node, other_node); - })) { - // Check passed for this group. 
- LOG_DEBUG("[KEEP] It is safe to keep this bank in function"); - for (auto&& const_node : instances) { - func_group.consts_to_keep.insert(const_node); - } - } else { - LOG_DEBUG("[CUT ] This group of Const ops will be cut-off from the function: " + bool is_tiny = ov::npuw::partitioning::traits::is_tiny_shape(proto_shape); + bool is_moe_gather = is_moe_gather_const(proto_node); + if (!is_tiny && !is_moe_gather) { + LOG_DEBUG("[CUT ] Not tiny shape and not MoE Gather indices - will be cut-off from the function"); + return; + } + + bool all_identical = std::all_of(instances.begin(), instances.end(), [&](const CTPtr& other_node) -> bool { + return (other_node->output(0).get_shape() == proto_node->output(0).get_shape()) && + values_are_the_same(proto_node, other_node); + }); + if (!all_identical) { + LOG_DEBUG("[CUT ] Values differ across instances - will be cut-off from the function: " << proto_node->get_friendly_name()); + return; + } + + if (is_moe_gather) { + LOG_DEBUG("[KEEP] MoE Gather indices constant - identical across all repeats"); + } else { + LOG_DEBUG("[KEEP] Tiny shape constant - safe to keep in function"); + } + for (auto&& const_node : instances) { + func_group.consts_to_keep.insert(const_node); } }; for (auto&& bank : rep_block.consts) { diff --git a/src/plugins/intel_npu/tests/unit/CMakeLists.txt b/src/plugins/intel_npu/tests/unit/CMakeLists.txt index 67d204adca0fbf..80496989ae5fac 100644 --- a/src/plugins/intel_npu/tests/unit/CMakeLists.txt +++ b/src/plugins/intel_npu/tests/unit/CMakeLists.txt @@ -45,6 +45,8 @@ ov_add_test_target( ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/host_flash_attention.cpp ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/moe_transformation.cpp ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/moe_unroll_patterns.cpp + ${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/device_routed_moe_transform.cpp + 
${OpenVINO_SOURCE_DIR}/src/plugins/intel_npu/src/plugin/npuw/moe_transformations/gather_to_2d_gather.cpp LINK_LIBRARIES ${MANDATORY_UNIT_TESTS_LIBS} LABELS diff --git a/src/plugins/intel_npu/tests/unit/npuw/device_routed_moe_transform_test.cpp b/src/plugins/intel_npu/tests/unit/npuw/device_routed_moe_transform_test.cpp new file mode 100644 index 00000000000000..6b254878186b71 --- /dev/null +++ b/src/plugins/intel_npu/tests/unit/npuw/device_routed_moe_transform_test.cpp @@ -0,0 +1,542 @@ +// Copyright (C) 2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "moe_transformations/device_routed_moe_transform.hpp" + +#include + +#include + +#include "openvino/op/ops.hpp" +#include "openvino/pass/manager.hpp" +#include "openvino/pass/serialize.hpp" + +/* + * Test suite for Device-Routed MoE Transformation + * + * Testing Strategy: + * - BasicTransformation: Verify Gather insertion in quantized weights and shape updates + * - MultiLayerMoE: Test independent transformation of multiple layers + * - AWQActivationMultiply: Test AWQ activation scaling support + */ + +// Uncomment to save debug XML files during test execution +// #define SAVE_TEST_MODELS + +namespace { + +using namespace ov; +using namespace ov::npuw::pass; + +// ============================================================================ +// Test Utilities +// ============================================================================ + +class DeviceRoutedMoETransformTest : public ::testing::Test { +protected: + void SetUp() override {} + void TearDown() override {} + + // Helper: Save model to XML for debugging + void save_model(const std::shared_ptr& model, const std::string& prefix) { +#ifdef SAVE_TEST_MODELS + std::string xml_path = prefix + ".xml"; + std::string bin_path = prefix + ".bin"; + ov::pass::Serialize serialize_pass(xml_path, bin_path); + serialize_pass.run_on_model(const_cast&>(model)); +#endif + } + + // Helper: Count nodes of specific type in model + template + 
size_t count_nodes(const std::shared_ptr& model) { + size_t count = 0; + for (const auto& node : model->get_ordered_ops()) { + if (std::dynamic_pointer_cast(node)) { + count++; + } + } + return count; + } + + // Helper: Find Gather nodes in model + std::vector> find_gather_nodes(const std::shared_ptr& model) { + std::vector> gathers; + for (const auto& node : model->get_ordered_ops()) { + if (auto gather = std::dynamic_pointer_cast(node)) { + gathers.push_back(gather); + } + } + return gathers; + } +}; + +// ============================================================================ +// Synthetic Graph Builders +// ============================================================================ + +// Create complete Router graph based on router.ir +// Accepts shared input Parameter, returns router scores for Expert multiply +std::shared_ptr create_router_graph(const std::shared_ptr& router_input, + int64_t k_value, + const std::string& layer_id, + size_t hidden_dim = 2880, + size_t num_experts = 32) { + // Router MatMul with quantized weights (INT4 -> FP16 -> Multiply -> FP32) + auto router_weights_int4 = op::v0::Constant::create(element::i4, + Shape{num_experts, hidden_dim}, + std::vector(num_experts * hidden_dim, 1)); + router_weights_int4->set_friendly_name(layer_id + "mlp.router.weight_int4"); + + auto router_weights_fp16 = std::make_shared(router_weights_int4, element::f16); + auto router_scale = + op::v0::Constant::create(element::f16, Shape{num_experts, 1}, std::vector(num_experts, 1.0f)); + auto router_weights_scaled = std::make_shared(router_weights_fp16, router_scale); + auto router_weights_fp32 = std::make_shared(router_weights_scaled, element::f32); + + auto router_matmul = std::make_shared(router_input, router_weights_fp32, false, true); + router_matmul->set_friendly_name("__module.model." 
+ layer_id + "mlp.router/aten::linear/MatMul"); + + // Router Add (bias) + auto router_bias = + op::v0::Constant::create(element::f32, Shape{1, num_experts}, std::vector(num_experts, 0.0f)); + auto router_add = std::make_shared(router_matmul, router_bias); + router_add->set_friendly_name("__module.model." + layer_id + "mlp.router/aten::linear/Add"); + + // TopK: [1, num_experts] -> values [1, K], indices [1, K] + auto k_const = op::v0::Constant::create(element::i64, Shape{}, std::vector{k_value}); + auto topk = std::make_shared(router_add, + k_const, + -1, + op::v11::TopK::Mode::MAX, + op::v11::TopK::SortType::NONE); + topk->set_friendly_name("__module.model." + layer_id + "mlp.router/aten::topk/TopK"); + + // Softmax on TopK values + auto softmax = std::make_shared(topk->output(0), 1); + softmax->set_friendly_name("__module.model." + layer_id + "mlp.router/aten::softmax/Softmax"); + + // ScatterElementsUpdate: scatter softmax back to [1, num_experts] + auto zeros = op::v0::Constant::create(element::f32, Shape{1, num_experts}, std::vector(num_experts, 0.0f)); + auto indices_i32 = std::make_shared(topk->output(1), element::i32); + auto scatter_axis = op::v0::Constant::create(element::i64, Shape{}, std::vector{1}); + auto scatter = std::make_shared(zeros, indices_i32, softmax, scatter_axis); + scatter->set_friendly_name("__module.model." + layer_id + "mlp.router/aten::scatter_/ScatterElementsUpdate"); + + // Transpose: [1, num_experts] -> [num_experts, 1] + auto transpose_order = op::v0::Constant::create(element::i32, Shape{2}, std::vector{1, 0}); + auto transpose = std::make_shared(scatter, transpose_order); + transpose->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/aten::transpose/Transpose"); + + // Reshape: [num_experts, 1] -> [num_experts, 1, 1] + auto reshape_shape = + op::v0::Constant::create(element::i64, Shape{3}, std::vector{static_cast(num_experts), 1, 1}); + auto reshape = std::make_shared(transpose, reshape_shape, false); + reshape->set_friendly_name("__module.model." + layer_id + "mlp.experts/aten::view/Reshape_2"); + + // Unsqueeze: [num_experts, 1, 1] -> [num_experts, 1, 1, 1] + auto unsqueeze_axis = op::v0::Constant::create(element::i64, Shape{}, std::vector{3}); + auto unsqueeze = std::make_shared(reshape, unsqueeze_axis); + unsqueeze->set_friendly_name("__module.model." + layer_id + "mlp.experts/aten::unsqueeze/Unsqueeze_2"); + + return unsqueeze; +} + +// Create complete MoE graph with Router + Expert (GPT-OSS pattern) +// Router and Expert share the same input Parameter +std::shared_ptr create_complete_moe_graph(size_t num_experts = 32, + int64_t k_value = 4, + size_t hidden_dim = 2880, + size_t token_count = 1, + const std::string& layer_id = "layers.0.", + bool with_awq_multiply = false) { + // 1. Create shared input Parameter for both Router and Expert + auto shared_input = std::make_shared(element::f32, Shape{token_count, hidden_dim}); + shared_input->set_friendly_name(layer_id + "input"); + + // 2. Create Router graph + auto router_scores_output = create_router_graph(shared_input, k_value, layer_id, hidden_dim, num_experts); + + // 3. Create Expert graph (GPT-OSS Expert pattern) + // Tile: [token_count, hidden_dim] -> [num_experts*token_count, hidden_dim] + auto repeats = + op::v0::Constant::create(element::i64, Shape{2}, std::vector{static_cast(num_experts), 1}); + auto tile = std::make_shared(shared_input, repeats); + tile->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/Tile"); + + // Reshape to 3D: [num_experts, token_count, hidden_dim] + auto reshape_shape1 = op::v0::Constant::create(element::i64, + Shape{3}, + std::vector{static_cast(num_experts), + static_cast(token_count), + static_cast(hidden_dim)}); + auto reshape1 = std::make_shared(tile, reshape_shape1, false); + reshape1->set_friendly_name("__module.model." + layer_id + "mlp.experts/Reshape"); + + // First MatMul (gate + up) with quantized weights [num_experts, hidden_dim*2, hidden_dim] + auto weights_int4_1 = op::v0::Constant::create(element::i4, + Shape{num_experts, hidden_dim * 2, hidden_dim}, + std::vector(num_experts * hidden_dim * 2 * hidden_dim, 1)); + weights_int4_1->set_friendly_name(layer_id + "mlp.experts.gate_up.weight_int4"); + + auto weights_fp16_1 = std::make_shared(weights_int4_1, element::f16); + auto weights_scale_1 = op::v0::Constant::create(element::f16, + Shape{num_experts, hidden_dim * 2, 1}, + std::vector(num_experts * hidden_dim * 2, 1.0f)); + auto weights_scaled_1 = std::make_shared(weights_fp16_1, weights_scale_1); + auto weights_fp32_1 = std::make_shared(weights_scaled_1, element::f32); + + auto matmul1 = std::make_shared(reshape1, weights_fp32_1, false, true); + matmul1->set_friendly_name("__module.model." + layer_id + "mlp.experts/MatMul_gate_up"); + + auto biases1 = op::v0::Constant::create(element::f32, + Shape{num_experts, 1, hidden_dim * 2}, + std::vector(num_experts * hidden_dim * 2, 0.0f)); + auto add1 = std::make_shared(matmul1, biases1); + add1->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/Add_gate_up"); + + // Dual branches: Activation branch (Slice -> Minimum -> Swish) and Gate branch (Slice -> Clamp -> Add) + // Activation branch + auto slice_start1 = op::v0::Constant::create(element::i64, Shape{1}, std::vector{0}); + auto slice_stop1 = + op::v0::Constant::create(element::i64, Shape{1}, std::vector{static_cast(hidden_dim)}); + auto slice_step1 = op::v0::Constant::create(element::i64, Shape{1}, std::vector{1}); + auto slice_axis1 = op::v0::Constant::create(element::i64, Shape{1}, std::vector{2}); + auto slice1 = std::make_shared(add1, slice_start1, slice_stop1, slice_step1, slice_axis1); + slice1->set_friendly_name("__module.model." + layer_id + "mlp.experts/Slice_activation"); + + auto minimum_const = op::v0::Constant::create(element::f32, Shape{1}, std::vector{20.0f}); + auto minimum = std::make_shared(slice1, minimum_const); + minimum->set_friendly_name("__module.model." + layer_id + "mlp.experts/Minimum"); + + auto swish_beta = op::v0::Constant::create(element::f32, Shape{}, std::vector{1.0f}); + auto swish = std::make_shared(minimum, swish_beta); + swish->set_friendly_name("__module.model." + layer_id + "mlp.experts/Swish"); + + // Optional AWQ activation multiply (after Swish) + std::shared_ptr activation_output = swish; + if (with_awq_multiply) { + auto awq_scale = op::v0::Constant::create(element::f32, + Shape{num_experts, 1, hidden_dim}, + std::vector(num_experts * hidden_dim, 1.0f)); + awq_scale->set_friendly_name(layer_id + "mlp.experts.awq_scale"); + + auto awq_multiply = std::make_shared(swish, awq_scale); + awq_multiply->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/AWQMultiply"); + activation_output = awq_multiply; + } + + // Gate branch + auto slice_start2 = + op::v0::Constant::create(element::i64, Shape{1}, std::vector{static_cast(hidden_dim)}); + auto slice_stop2 = + op::v0::Constant::create(element::i64, Shape{1}, std::vector{static_cast(hidden_dim * 2)}); + auto slice_step2 = op::v0::Constant::create(element::i64, Shape{1}, std::vector{1}); + auto slice_axis2 = op::v0::Constant::create(element::i64, Shape{1}, std::vector{2}); + auto slice2 = std::make_shared(add1, slice_start2, slice_stop2, slice_step2, slice_axis2); + slice2->set_friendly_name("__module.model." + layer_id + "mlp.experts/Slice_gate"); + + auto clamp = std::make_shared(slice2, -20.0f, 20.0f); + clamp->set_friendly_name("__module.model." + layer_id + "mlp.experts/Clamp"); + + auto add2_const = op::v0::Constant::create(element::f32, Shape{1}, std::vector{0.0f}); + auto add2 = std::make_shared(clamp, add2_const); + add2->set_friendly_name("__module.model." + layer_id + "mlp.experts/Add_gate"); + + // Merge branches: Multiply (use activation_output which may include AWQ multiply) + auto multiply1 = std::make_shared(activation_output, add2); + multiply1->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/Multiply_merge"); + + // Second MatMul (down projection) with quantized weights [num_experts, hidden_dim, hidden_dim] + auto weights_int4_2 = op::v0::Constant::create(element::i4, + Shape{num_experts, hidden_dim, hidden_dim}, + std::vector(num_experts * hidden_dim * hidden_dim, 1)); + weights_int4_2->set_friendly_name(layer_id + "mlp.experts.down.weight_int4"); + + auto weights_fp16_2 = std::make_shared(weights_int4_2, element::f16); + auto weights_scale_2 = op::v0::Constant::create(element::f16, + Shape{num_experts, hidden_dim, 1}, + std::vector(num_experts * hidden_dim, 1.0f)); + auto weights_scaled_2 = std::make_shared(weights_fp16_2, weights_scale_2); + auto weights_fp32_2 = std::make_shared(weights_scaled_2, element::f32); + + auto matmul2 = std::make_shared(multiply1, weights_fp32_2, false, true); + matmul2->set_friendly_name("__module.model." + layer_id + "mlp.experts/MatMul_down"); + + auto biases2 = op::v0::Constant::create(element::f32, + Shape{num_experts, 1, hidden_dim}, + std::vector(num_experts * hidden_dim, 0.0f)); + auto add3 = std::make_shared(matmul2, biases2); + add3->set_friendly_name("__module.model." + layer_id + "mlp.experts/Add_down"); + + // Output reshape + auto reshape_shape2 = op::v0::Constant::create(element::i64, + Shape{3}, + std::vector{static_cast(num_experts), + static_cast(token_count), + static_cast(hidden_dim)}); + auto reshape2 = std::make_shared(add3, reshape_shape2, false); + reshape2->set_friendly_name("__module.model." + layer_id + "mlp.experts/Reshape_out"); + + // Multiply with router scores + auto output_multiply = std::make_shared(reshape2, router_scores_output); + output_multiply->set_friendly_name("__module.model." 
+ layer_id + "mlp.experts/aten::mul/Multiply"); + + // Result + auto result = std::make_shared(output_multiply); + result->set_friendly_name(layer_id + "output"); + + return std::make_shared(ResultVector{result}, ParameterVector{shared_input}); +} + +// ============================================================================ +// Unit Tests +// ============================================================================ + +// Test 1: Basic transformation - Verify Gather insertion in quantized weights and shape updates +TEST_F(DeviceRoutedMoETransformTest, BasicTransformation) { + constexpr size_t num_experts = 8; + constexpr int64_t k_value = 4; + constexpr size_t hidden_dim = 2880; + + auto model = create_complete_moe_graph(num_experts, k_value, hidden_dim, 1, "layers.0."); + save_model(model, "device_routed_moe_basic_before"); + + // Verify initial state + EXPECT_EQ(count_nodes(model), 0) << "Should have no Gather before transformation"; + EXPECT_EQ(count_nodes(model), 1) << "Should have 1 Tile node"; + + // Apply transformation + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(model); + + // Validate + EXPECT_NO_THROW(model->validate_nodes_and_infer_types()); + + save_model(model, "device_routed_moe_basic_after"); + + // Verify Gather insertion: gate_up weights + gate_up scale + gate_up biases + down weights + down scale + down + // biases = 6 + auto gathers = find_gather_nodes(model); + EXPECT_EQ(gathers.size(), 6) << "Should have 6 Gather nodes after transformation"; + + // Verify Gather inserted in quantization chains (INT4->FP16->Gather->Multiply->FP32) + // Note: Gather output may go through Convert before reaching Multiply + size_t gathers_in_quant_chain = 0; + for (const auto& gather : gathers) { + // Check if Gather output feeds into Multiply (directly or through Convert) + auto target_inputs = gather->output(0).get_target_inputs(); + if (target_inputs.empty()) + continue; + + auto next_node = 
target_inputs.begin()->get_node()->shared_from_this(); + + // Direct path: Gather -> Multiply + if (std::dynamic_pointer_cast(next_node)) { + gathers_in_quant_chain++; + continue; + } + + // Path with Convert: Gather -> Convert -> Multiply + if (auto convert = std::dynamic_pointer_cast(next_node)) { + auto convert_targets = convert->output(0).get_target_inputs(); + if (!convert_targets.empty()) { + if (std::dynamic_pointer_cast( + convert_targets.begin()->get_node()->shared_from_this())) { + gathers_in_quant_chain++; + } + } + } + } + EXPECT_GE(gathers_in_quant_chain, 4) + << "Gather should be inserted in quantization chains for gate_up and down weights (weights + scales)"; + + // Verify Tile repeats updated from num_experts to k_value + for (const auto& node : model->get_ordered_ops()) { + if (auto tile = std::dynamic_pointer_cast(node)) { + auto repeats_const = + std::dynamic_pointer_cast(tile->input_value(1).get_node_shared_ptr()); + ASSERT_NE(repeats_const, nullptr); + auto repeats_data = repeats_const->cast_vector(); + EXPECT_EQ(repeats_data[0], k_value) << "Tile repeats should be updated to K=" << k_value; + } + } + + // Verify Reshape shapes updated + for (const auto& node : model->get_ordered_ops()) { + if (auto reshape = std::dynamic_pointer_cast(node)) { + auto shape_const = + std::dynamic_pointer_cast(reshape->input_value(1).get_node_shared_ptr()); + if (shape_const) { + auto shape_data = shape_const->cast_vector(); + if (shape_data.size() == 3 && shape_data[0] > 1) { + EXPECT_EQ(shape_data[0], k_value) << "Reshape expert dimension should be updated to K=" << k_value; + } + } + } + } +} + +// Test 2: Multi-layer MoE +TEST_F(DeviceRoutedMoETransformTest, MultiLayerMoE) { + constexpr size_t num_experts = 8; + constexpr int64_t k_value_layer0 = 4; + constexpr int64_t k_value_layer1 = 2; + constexpr size_t hidden_dim = 2880; + + auto model_layer0 = create_complete_moe_graph(num_experts, k_value_layer0, hidden_dim, 1, "layers.0."); + auto model_layer1 = 
create_complete_moe_graph(num_experts, k_value_layer1, hidden_dim, 1, "layers.1."); + + // Merge models + ov::ParameterVector all_params; + ov::ResultVector all_results; + for (const auto& param : model_layer0->get_parameters()) + all_params.push_back(param); + for (const auto& param : model_layer1->get_parameters()) + all_params.push_back(param); + for (const auto& result : model_layer0->get_results()) + all_results.push_back(result); + for (const auto& result : model_layer1->get_results()) + all_results.push_back(result); + + auto merged_model = std::make_shared(all_results, all_params); + save_model(merged_model, "device_routed_moe_multi_layer_before"); + + // Apply transformation + ov::pass::Manager manager; + manager.register_pass(); + manager.run_passes(merged_model); + + EXPECT_NO_THROW(merged_model->validate_nodes_and_infer_types()); + + save_model(merged_model, "device_routed_moe_multi_layer_after"); + + // Verify both layers transformed independently + // Each layer: 6 Gathers (gate_up weights + scale + biases + down weights + scale + biases) = 12 total + auto gathers = find_gather_nodes(merged_model); + EXPECT_EQ(gathers.size(), 12) << "Should have 12 Gather nodes (6 per layer)"; + + // Verify Tile repeats per layer + size_t layer0_tiles = 0, layer1_tiles = 0; + for (const auto& node : merged_model->get_ordered_ops()) { + if (auto tile = std::dynamic_pointer_cast(node)) { + std::string name = tile->get_friendly_name(); + auto repeats_const = + std::dynamic_pointer_cast(tile->input_value(1).get_node_shared_ptr()); + if (repeats_const) { + auto repeats_data = repeats_const->cast_vector(); + if (name.find("layers.0.") != std::string::npos) { + EXPECT_EQ(repeats_data[0], k_value_layer0); + layer0_tiles++; + } else if (name.find("layers.1.") != std::string::npos) { + EXPECT_EQ(repeats_data[0], k_value_layer1); + layer1_tiles++; + } + } + } + } + EXPECT_EQ(layer0_tiles, 1); + EXPECT_EQ(layer1_tiles, 1); +} + +// Test 3: AWQ activation multiply support 
+TEST_F(DeviceRoutedMoETransformTest, AWQActivationMultiply) {
+    constexpr size_t num_experts = 8;
+    constexpr int64_t k_value = 4;
+    constexpr size_t hidden_dim = 2880;
+
+    auto model = create_complete_moe_graph(num_experts, k_value, hidden_dim, 1, "layers.0.", true);
+    save_model(model, "device_routed_moe_awq_before");
+
+    // Apply transformation
+    ov::pass::Manager manager;
+    manager.register_pass<ov::npuw::pass::DeviceRoutedMoETransform>();
+    manager.run_passes(model);
+
+    EXPECT_NO_THROW(model->validate_nodes_and_infer_types());
+
+    save_model(model, "device_routed_moe_awq_after");
+
+    // Verify the Multiply whose friendly name contains "AWQMultiply" exists, then check both of its inputs
+    bool found_awq_multiply = false;
+    for (const auto& node : model->get_ordered_ops()) {
+        if (auto mult = std::dynamic_pointer_cast<op::v1::Multiply>(node)) {
+            if (mult->get_friendly_name().find("AWQMultiply") != std::string::npos) {
+                found_awq_multiply = true;
+
+                // Verify one input is from Swish
+                bool has_swish_input = false;
+                for (size_t i = 0; i < 2; ++i) {
+                    auto input = mult->input_value(i).get_node_shared_ptr();
+                    if (std::dynamic_pointer_cast<op::v4::Swish>(input)) {
+                        has_swish_input = true;
+                        break;
+                    }
+                }
+                EXPECT_TRUE(has_swish_input) << "AWQ Multiply should have Swish as input";
+
+                // Verify other input is Gather (AWQ scale after transformation)
+                bool has_gather_input = false;
+                for (size_t i = 0; i < 2; ++i) {
+                    auto input = mult->input_value(i).get_node_shared_ptr();
+                    if (auto gather = std::dynamic_pointer_cast<op::v8::Gather>(input)) {
+                        auto gather_shape = gather->get_output_shape(0);
+                        ASSERT_EQ(gather_shape.size(), 3u) << "AWQ scale Gather output should be 3D";  // guards [0] below
+                        EXPECT_EQ(gather_shape[0], k_value) << "AWQ scale expert dimension should be K after Gather";
+                        has_gather_input = true;
+                        break;
+                    }
+                }
+                EXPECT_TRUE(has_gather_input)
+                    << "AWQ Multiply should have Gather (AWQ scale) as input after transformation";
+                break;
+            }
+        }
+    }
+    EXPECT_TRUE(found_awq_multiply) << "AWQ Multiply node should be present when enabled";
+
+    // Verify Gather inserted for AWQ scale
+    auto gathers = find_gather_nodes(model);
+    // gate_up weights + scale + biases + down weights + scale + biases + AWQ scale = 7
+    EXPECT_EQ(gathers.size(), 7u) << "Should have 7 Gather nodes (including AWQ scale)";
+}
+
+// Test 4: Negative test - No Softmax
+TEST_F(DeviceRoutedMoETransformTest, NegativeNoSoftmax) {
+    constexpr size_t num_experts = 8;
+    constexpr int64_t k_value = 4;
+    constexpr size_t hidden_dim = 2880;
+
+    auto input = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_dim});
+    auto k_const = op::v0::Constant::create(element::i64, Shape{}, std::vector<int64_t>{k_value});
+    auto topk =
+        std::make_shared<op::v11::TopK>(input, k_const, 1, op::v11::TopK::Mode::MAX, op::v11::TopK::SortType::NONE);
+    topk->set_friendly_name("__module.model.layers.0.mlp.router/aten::topk/TopK");
+
+    // NO Softmax - transformation should skip
+    // NOTE(review): topk is not an ancestor of the Result, so it appears to be outside the model graph - confirm intent
+    auto repeats =
+        op::v0::Constant::create(element::i64, Shape{2}, std::vector<int64_t>{static_cast<int64_t>(num_experts), 1});
+    auto tile = std::make_shared<op::v0::Tile>(input, repeats);
+    tile->set_friendly_name("__module.model.layers.0.mlp.experts/Tile");
+
+    auto result = std::make_shared<op::v0::Result>(tile);
+    auto model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{input});
+    save_model(model, "device_routed_moe_negative_before");
+
+    // Apply transformation
+    ov::pass::Manager manager;
+    manager.register_pass<ov::npuw::pass::DeviceRoutedMoETransform>();
+    bool changed = manager.run_passes(model);
+
+    // Should not transform: the pass reports no change and Tile still repeats num_experts times on axis 0
+    EXPECT_FALSE(changed) << "Transformation should skip when Softmax is missing";
+    auto repeats_const = std::dynamic_pointer_cast<op::v0::Constant>(tile->input_value(1).get_node_shared_ptr());
+    ASSERT_NE(repeats_const, nullptr) << "Tile repeats input should still be a Constant";
+    auto repeats_data = repeats_const->cast_vector<int64_t>();
+    EXPECT_EQ(repeats_data.at(0), static_cast<int64_t>(num_experts)) << "Tile repeats should remain unchanged";
+}
+
+}  // namespace
diff --git a/src/plugins/intel_npu/tests/unit/npuw/gather_to_2d_gather_test.cpp b/src/plugins/intel_npu/tests/unit/npuw/gather_to_2d_gather_test.cpp
new file mode 100644
index 00000000000000..f32f96a3d418e8
--- /dev/null
+++ b/src/plugins/intel_npu/tests/unit/npuw/gather_to_2d_gather_test.cpp
@@ -0,0 +1,314 @@
+// Copyright (C) 2026 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "moe_transformations/gather_to_2d_gather.hpp"
+
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "openvino/op/ops.hpp"
+#include "openvino/pass/manager.hpp"
+#include "openvino/pass/serialize.hpp"
+
+/*
+ * Test suite for GatherTo2DGather Transformation
+ *
+ * Testing Strategy:
+ * - BasicTransformation: Verify 3D->2D Gather transformation and output shapes
+ * - NegativeAxis: Ensure transformation skips when axis != 0
+ * - Negative2DData: Ensure transformation skips for 2D data (1D data is not exercised here)
+ * - NegativeSingleDimension: Ensure transformation skips when M=1 (K=1 is not exercised here)
+ * - LargeDimensions: Stress test with realistic MoE sizes
+ */
+
+// Uncomment to save debug XML files during test execution
+// #define SAVE_TEST_MODELS
+
+namespace {
+
+using namespace ov;
+using namespace ov::npuw::pass;
+
+// ============================================================================
+// Test Utilities
+// ============================================================================
+
+class GatherTo2DGatherTest : public ::testing::Test {
+protected:
+    void SetUp() override {}
+    void TearDown() override {}
+
+    // Helper: Serialize model to <prefix>.xml / <prefix>.bin; a no-op unless SAVE_TEST_MODELS is defined
+    void save_model(const std::shared_ptr<ov::Model>& model, const std::string& prefix) {
+#ifdef SAVE_TEST_MODELS
+        std::string xml_path = prefix + ".xml";
+        std::string bin_path = prefix + ".bin";
+        ov::pass::Serialize serialize_pass(xml_path, bin_path);
+        serialize_pass.run_on_model(const_cast<std::shared_ptr<ov::Model>&>(model));
+#endif
+    }
+
+    // Helper: Count the ops in model->get_ordered_ops() that dynamic-cast to node type T
+    template <typename T>
+    size_t count_nodes(const std::shared_ptr<ov::Model>& model) {
+        size_t count = 0;
+        for (const auto& node : model->get_ordered_ops()) {
+            if (std::dynamic_pointer_cast<T>(node)) {
+                count++;
+            }
+        }
+        return count;
+    }
+
+    // Helper: Expect one Gather/Tile/Multiply/Add, a scalar constant M, a [0..M-1] range and an [I, M, K] output
+    void validate_transformation(const std::shared_ptr<ov::Model>& model, int64_t I, int64_t M, int64_t K) {
+        // Verify node counts
+        EXPECT_EQ(count_nodes<op::v8::Gather>(model), 1u) << "Should have 1 Gather after transformation";
+        EXPECT_EQ(count_nodes<op::v0::Tile>(model), 1u) << "Should have 1 Tile";
+        EXPECT_EQ(count_nodes<op::v1::Multiply>(model), 1u) << "Should have 1 Multiply";
+        EXPECT_EQ(count_nodes<op::v1::Add>(model), 1u) << "Should have 1 Add";
+
+        // Verify some Multiply has a single-element Constant input equal to M
+        bool found_multiply_constant = false;
+        for (const auto& node : model->get_ordered_ops()) {
+            if (auto multiply = std::dynamic_pointer_cast<op::v1::Multiply>(node)) {
+                for (size_t i = 0; i < 2; ++i) {
+                    auto input = multiply->input_value(i).get_node_shared_ptr();
+                    if (auto constant = std::dynamic_pointer_cast<op::v0::Constant>(input)) {
+                        auto constant_data = constant->cast_vector<int64_t>();
+                        if (constant_data.size() == 1 && constant_data[0] == M) {
+                            found_multiply_constant = true;
+                            break;
+                        }
+                    }
+                }
+            }
+        }
+        EXPECT_TRUE(found_multiply_constant) << "Multiply should have M=" << M << " as constant";
+
+        // Verify some Add is fed by a Tile whose data input is the Constant range [0, 1, 2, ..., M-1]
+        bool found_range_constant = false;
+        for (const auto& node : model->get_ordered_ops()) {
+            auto add = std::dynamic_pointer_cast<op::v1::Add>(node);
+            if (!add)
+                continue;
+
+            for (size_t i = 0; i < 2; ++i) {
+                auto input = add->input_value(i).get_node_shared_ptr();
+                auto tile = std::dynamic_pointer_cast<op::v0::Tile>(input);
+                if (!tile)
+                    continue;
+
+                auto range_input = tile->input_value(0).get_node_shared_ptr();
+                auto constant = std::dynamic_pointer_cast<op::v0::Constant>(range_input);
+                if (!constant)
+                    continue;
+
+                auto range_data = constant->cast_vector<int64_t>();
+                if (range_data.size() != static_cast<size_t>(M))
+                    continue;
+
+                bool valid_range = true;
+                for (size_t j = 0; j < range_data.size(); ++j) {
+                    if (range_data[j] != static_cast<int64_t>(j)) {
+                        valid_range = false;
+                        break;
+                    }
+                }
+
+                if (valid_range) {
+                    found_range_constant = true;
+                    break;
+                }
+            }
+
+            if (found_range_constant)
+                break;
+        }
+        EXPECT_TRUE(found_range_constant) << "Add should have range [0, M-1] constant";
+
+        // Verify output shape from model results (ASSERT_* returns from this helper before any indexing)
+        auto results = model->get_results();
+        ASSERT_EQ(results.size(), 1u) << "Model should have 1 result";
+        auto output_shape = results[0]->get_output_shape(0);
+
+        ASSERT_EQ(output_shape.size(), 3u) << "Output should be 3D";
+        EXPECT_EQ(output_shape[0], static_cast<size_t>(I)) << "First dimension should be I (num_selected)";
+        EXPECT_EQ(output_shape[1], static_cast<size_t>(M)) << "Second dimension should be M (feature_dim)";
+        EXPECT_EQ(output_shape[2], static_cast<size_t>(K)) << "Third dimension should be K (hidden_dim)";
+    }
+};
+
+// ============================================================================
+// Synthetic Graph Builders
+// ============================================================================
+
+// Create a simple 3D Gather graph made only of Constants (the model has no Parameters)
+// data: [N, M, K], indices: [I], axis: 0 -> output: [I, M, K]
+std::shared_ptr<ov::Model> create_3d_gather_graph(int64_t N,
+                                                  int64_t M,
+                                                  int64_t K,
+                                                  int64_t I,
+                                                  const std::string& name_prefix = "gather") {
+    // Data input [N, M, K], all 1.0f: a single value is broadcast, so no N*M*K temporary vector is built
+    auto data = op::v0::Constant::create(element::f32,
+                                         Shape{static_cast<size_t>(N), static_cast<size_t>(M), static_cast<size_t>(K)},
+                                         std::vector<float>{1.0f});
+    data->set_friendly_name(name_prefix + "_data");
+
+    // Indices [I]
+    std::vector<int64_t> indices_data(I);
+    for (int64_t i = 0; i < I; ++i) {
+        indices_data[i] = i % N;  // Valid indices within [0, N)
+    }
+    auto indices = op::v0::Constant::create(element::i64, Shape{static_cast<size_t>(I)}, indices_data);
+    indices->set_friendly_name(name_prefix + "_indices");
+
+    // Axis = 0
+    auto axis = op::v0::Constant::create(element::i64, Shape{}, std::vector<int64_t>{0});
+
+    // Gather
+    auto gather = std::make_shared<op::v8::Gather>(data, indices, axis);
+    gather->set_friendly_name(name_prefix);
+
+    // Result
+    auto result = std::make_shared<op::v0::Result>(gather);
+    result->set_friendly_name(name_prefix + "_output");
+
+    return std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{});
+}
+
+// Create Gather graph over [8, 16, 32] data with 4 indices and the given axis (callers pass a non-zero axis)
+std::shared_ptr<ov::Model> create_gather_with_axis(int64_t axis_value) {
+    auto data = op::v0::Constant::create(element::f32, Shape{8, 16, 32}, std::vector<float>(8 * 16 * 32, 1.0f));
+    auto indices = op::v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 1, 2, 3});
+    auto axis = op::v0::Constant::create(element::i64, Shape{}, std::vector<int64_t>{axis_value});
+
+    auto gather = std::make_shared<op::v8::Gather>(data, indices, axis);
+    gather->set_friendly_name("gather_axis_" + std::to_string(axis_value));
+
+    auto result = std::make_shared<op::v0::Result>(gather);
+    return std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{});
+}
+
+// Create Gather graph with 2D data [8, 32], 4 indices, axis 0
+std::shared_ptr<ov::Model> create_gather_with_2d_data() {
+    auto data = op::v0::Constant::create(element::f32, Shape{8, 32}, std::vector<float>(8 * 32, 1.0f));
+    auto indices = op::v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 1, 2, 3});
+    auto axis = op::v0::Constant::create(element::i64, Shape{}, std::vector<int64_t>{0});
+
+    auto gather = std::make_shared<op::v8::Gather>(data, indices, axis);
+    gather->set_friendly_name("gather_2d");
+
+    auto result = std::make_shared<op::v0::Result>(gather);
+    return std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{});
+}
+
+// Create Gather graph with M=1 (single feature dimension): data [8, 1, 32], 4 indices, axis 0
+std::shared_ptr<ov::Model> create_gather_with_single_m() {
+    auto data = op::v0::Constant::create(element::f32, Shape{8, 1, 32}, std::vector<float>(8 * 1 * 32, 1.0f));
+    auto indices = op::v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 1, 2, 3});
+    auto axis = op::v0::Constant::create(element::i64, Shape{}, std::vector<int64_t>{0});
+
+    auto gather = std::make_shared<op::v8::Gather>(data, indices, axis);
+    gather->set_friendly_name("gather_m1");
+
+    auto result = std::make_shared<op::v0::Result>(gather);
+    return std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{});
+}
+
+// ============================================================================
+// Unit Tests
+// ============================================================================
+
+// Test 1: Basic transformation - Verify 3D->2D Gather transformation
+TEST_F(GatherTo2DGatherTest, BasicTransformation) {
+    constexpr int64_t N = 8;   // num_experts
+    constexpr int64_t M = 16;  // feature_dim
+    constexpr int64_t K = 32;  // hidden_dim
+    constexpr int64_t I = 4;   // num_selected
+
+    auto model = create_3d_gather_graph(N, M, K, I);
+    save_model(model, "gather_to_2d_basic_before");
+
+    // Verify initial state
+    EXPECT_EQ(count_nodes<op::v8::Gather>(model), 1u) << "Should have 1 Gather before transformation";
+    EXPECT_EQ(count_nodes<op::v1::Reshape>(model), 0u) << "Should have no Reshape before transformation";
+    EXPECT_EQ(count_nodes<op::v0::Tile>(model), 0u) << "Should have no Tile before transformation";
+
+    // Apply transformation
+    ov::pass::Manager manager;
+    manager.register_pass<GatherTo2DGather>();
+    bool changed = manager.run_passes(model);
+
+    EXPECT_TRUE(changed) << "Transformation should modify the graph";
+    EXPECT_NO_THROW(model->validate_nodes_and_infer_types());
+
+    save_model(model, "gather_to_2d_basic_after");
+
+    // Validate transformation results
+    validate_transformation(model, I, M, K);
+}
+
+// Test 2: Negative test - Axis != 0
+TEST_F(GatherTo2DGatherTest, NegativeAxis) {
+    auto model = create_gather_with_axis(1);  // axis = 1
+    save_model(model, "gather_to_2d_negative_axis_before");
+
+    ov::pass::Manager manager;
+    manager.register_pass<GatherTo2DGather>();
+    bool changed = manager.run_passes(model);
+
+    EXPECT_FALSE(changed) << "Transformation should skip when axis != 0";
+}
+
+// Test 3: Negative test - 2D data
+TEST_F(GatherTo2DGatherTest, Negative2DData) {
+    auto model = create_gather_with_2d_data();
+    save_model(model, "gather_to_2d_negative_2d_before");
+
+    ov::pass::Manager manager;
+    manager.register_pass<GatherTo2DGather>();
+    bool changed = manager.run_passes(model);
+
+    EXPECT_FALSE(changed) << "Transformation should skip for 2D data";
+}
+
+// Test 4: Negative test - M=1 (transformation not beneficial)
+TEST_F(GatherTo2DGatherTest, NegativeSingleDimension) {
+    auto model = create_gather_with_single_m();
+    save_model(model, "gather_to_2d_negative_m1_before");
+
+    ov::pass::Manager manager;
+    manager.register_pass<GatherTo2DGather>();
+    bool changed = manager.run_passes(model);
+
+    EXPECT_FALSE(changed) << "Transformation should skip when M=1 (not beneficial)";
+}
+
+// Test 5: Large dimensions - Stress test with realistic MoE sizes
+TEST_F(GatherTo2DGatherTest, LargeDimensions) {
+    constexpr int64_t N = 32;    // 32 experts
+    constexpr int64_t M = 2880;  // feature_dim (typical for large models)
+    constexpr int64_t K = 2880;  // hidden_dim; N*M*K f32 elements make the data constant roughly 1 GiB
+    constexpr int64_t I = 4;     // top-4 routing
+
+    auto model = create_3d_gather_graph(N, M, K, I);
+    save_model(model, "gather_to_2d_large_before");
+
+    ov::pass::Manager manager;
+    manager.register_pass<GatherTo2DGather>();
+    bool changed = manager.run_passes(model);
+
+    EXPECT_TRUE(changed) << "Transformation should handle large dimensions";
+    EXPECT_NO_THROW(model->validate_nodes_and_infer_types());
+
+    save_model(model, "gather_to_2d_large_after");
+
+    // Validate transformation results: same node-count, constant and [I, M, K] shape checks as the basic test
+    validate_transformation(model, I, M, K);
+}
+
+}  // namespace