
[GPU] Qwen3 MoE support #30448

Open
wants to merge 86 commits into base: master

Conversation

@riverlijunjie (Contributor) commented May 7, 2025

Support Qwen3 MoE model running with GPU plugin

Details:

  • Fuse the MoE subgraph into a single moe_expert op to reduce the total number of ops and improve compile_model and inference performance.
  • moe_expert primitive execution stages:
    • First token: uses the oneDNN GEMM kernel pipeline plus optimized OpenCL kernels (gather, scatter) for MoE execution; experts are executed serially.
    • Second token: uses optimized OpenCL kernels (mlp_gate_up, mlp_down, softmax_topk, reduce) to execute multiple experts in parallel.
  • Each layer's MoE weights are allocated in one single USM memory, and sub-memories are created from it for each expert's weights/scale/zp, which helps the second token's expert kernels run in parallel (see the layout sketch after this list).
  • Optimize the key_cache and value_cache inputs.
  • MoE support is limited to u4 weights, f16 scales, u4 zero points, and group_size=128, as required by the Qwen3 MoE 30B model.
  • Only systolic GPUs (A770/B580/ARL/LNL) are supported; MTL is not, because the first token needs to call the oneDNN GEMM kernel.
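As a rough illustration of the single-buffer weight layout mentioned above (hypothetical types, and only one weight matrix per expert for brevity; the real plugin derives sizes from the op config and allocates USM device memory), each expert's u4 weights, f16 scales and u4 zero points sit at fixed offsets inside one contiguous allocation, so the parallel second-token kernels only need a base pointer plus per-expert offsets:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-expert view into one contiguous weight allocation.
// Sizes follow the u4 weight / f16 scale / u4 zp, group_size=128 scheme
// described above.
struct ExpertView {
    size_t weight_offset;  // u4-packed weights (2 values per byte)
    size_t scale_offset;   // f16 scales, one per group of input channels
    size_t zp_offset;      // u4-packed zero points, one per group
};

struct MoeWeightLayout {
    size_t total_bytes = 0;
    std::vector<ExpertView> experts;
};

// Compute offsets for `num_experts` experts, each holding a [rows x cols]
// u4 weight matrix quantized per group of `group_size` input channels.
MoeWeightLayout build_layout(size_t num_experts, size_t rows, size_t cols,
                             size_t group_size = 128) {
    MoeWeightLayout layout;
    const size_t wei_bytes   = rows * cols / 2;                       // u4
    const size_t num_groups  = cols / group_size;
    const size_t scale_bytes = rows * num_groups * sizeof(uint16_t);  // f16
    const size_t zp_bytes    = rows * num_groups / 2;                 // u4
    for (size_t e = 0; e < num_experts; ++e) {
        ExpertView v;
        v.weight_offset = layout.total_bytes;
        v.scale_offset  = v.weight_offset + wei_bytes;
        v.zp_offset     = v.scale_offset + scale_bytes;
        layout.total_bytes = v.zp_offset + zp_bytes;
        layout.experts.push_back(v);
    }
    // One USM allocation of layout.total_bytes is then created, and each
    // expert kernel argument is just base_ptr + the corresponding offset.
    return layout;
}
```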

MoE fusion result

Original MoE exec graph (contains 128 experts):
[exec graph screenshot]

With this PR it becomes one single moe_expert op:
[exec graph screenshot]
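For reference, the computation covered by the fused moe_expert op can be written as plain scalar C++. This is an illustration only, with simplifications (float weights, a generic expert callable, and top-k score normalization in the style of Qwen-MoE routers); the real primitive runs the quantized oneDNN/OpenCL pipelines described above:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

using Vec = std::vector<float>;
// An expert is any callable mapping hidden state -> hidden state
// (in the real graph: down( silu(gate(x)) * up(x) ) per expert).
using Expert = std::function<Vec(const Vec&)>;

// Scalar reference for one token: softmax + top-k routing, run only the
// selected experts, and reduce their outputs weighted by the normalized
// routing scores.
Vec moe_forward(const Vec& hidden, const Vec& router_logits,
                const std::vector<Expert>& experts, size_t top_k) {
    // softmax over router logits
    Vec probs(router_logits.size());
    const float max_logit =
        *std::max_element(router_logits.begin(), router_logits.end());
    float sum = 0.f;
    for (size_t i = 0; i < probs.size(); ++i) {
        probs[i] = std::exp(router_logits[i] - max_logit);
        sum += probs[i];
    }
    for (auto& p : probs) p /= sum;

    // top-k expert indices by routing probability
    std::vector<size_t> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                      [&](size_t a, size_t b) { return probs[a] > probs[b]; });

    // normalize the selected scores and reduce the expert outputs
    float selected_sum = 0.f;
    for (size_t k = 0; k < top_k; ++k) selected_sum += probs[idx[k]];
    Vec out(hidden.size(), 0.f);
    for (size_t k = 0; k < top_k; ++k) {
        const Vec expert_out = experts[idx[k]](hidden);
        const float w = probs[idx[k]] / selected_sum;
        for (size_t i = 0; i < out.size(); ++i) out[i] += w * expert_out[i];
    }
    return out;
}
```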

TODO:

  • Support more MoE patterns; currently only the Qwen3 MoE pattern is verified and supported.
  • Integrate the optimized CM kernel for second-token MoE.
  • Align the CM kernel to use the same scale/zp layout as the OpenCL kernel.
  • Support more MoE data types: u8 weights.
  • Support other subgroup sizes: 32, 64, 256, ...

Tickets:

luo-cheng2021 and others added 30 commits April 8, 2025 09:29
* Build Subgraph in parallel to improve compile_model performance

* SharedOpOptimization optimizes attribute visit

---------

Co-authored-by: Tingqian Li <[email protected]>
@itikhono (Contributor) commented Jul 4, 2025

@CuriousPanCake could you take a look?

// Attach this expert's constants to the previously created MOE op;
// every fused expert must share the same config.
auto prev_moe = pattern_map.at(final_hidden_states).get_node_shared_ptr();
auto moe = ov::as_type_ptr<op::internal::MOE>(prev_moe);
OPENVINO_ASSERT(config == moe->get_config(), "each expert config must be same");
moe->add_consts(static_cast<size_t>(expert_no), consts);
@yeonbok (Contributor) commented Jul 7, 2025
How long does this take (adding and copying the constants)? If it is time consuming, there is room to concatenate those weights and then copy them in parallel on the GPU.

@riverlijunjie (Contributor, Author) replied:
We previously tried copying in parallel on the GPU, but found some race condition issues. Maybe put it down as a TODO?

// moe primitive constructor: keep the fused MoE config, the per-expert MLP
// parameters, and the pre-packed weights memory alongside the regular inputs.
    : primitive_base(id, inputs, 1, {optional_data_type()}),
      _config(config),
      _mlp_params(param),
      _mlp_weights_mem(wei_mem) {
@yeonbok (Contributor) commented Jul 7, 2025
Can I ask why we need this special memory descriptor?
If we did the following:

  1. primitive: let the merged weights be a regular input node (data)
  2. at transform: weight1, 2, 3 => concat
  3. then, at the post-weight-optimization phase in the GPU plugin transform, let them be fused into one data node by the GPU
  4. then the GPU moe primitive would have a single data input
  5. then it would be saved and loaded on the normal path

Could you please let me know why the above does not work?

@riverlijunjie (Contributor, Author) replied:
The weight data needs to be repacked for all experts of the MoE, which is not easy to do in the transformation stage, so we do it in CreateMOEOp (src/plugins/intel_gpu/src/plugin/ops/moe.cpp) and then pass it to the primitive as mlp_weights_mem. A rough sketch of that flow is shown below.
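A minimal sketch of that flow, using hypothetical stand-in names (DeviceMemory, MoeConfig, create_moe_expert) rather than the actual plugin types: the per-expert constants collected during fusion are repacked host-side into the kernel layout at compile_model time, uploaded once, and the resulting single memory handle is handed to the primitive as its mlp_weights_mem.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

// Hypothetical stand-ins for plugin-internal types, for illustration only.
struct DeviceMemory { std::vector<uint8_t> bytes; };  // stands in for a USM allocation
struct MoeConfig { size_t num_experts = 0; size_t group_size = 128; };

struct MoeExpertPrim {
    MoeConfig config;
    std::shared_ptr<DeviceMemory> mlp_weights_mem;  // pre-repacked expert weights
};

// Placeholder repack: the real code reorders u4 weights / f16 scales / u4
// zero points into the layout the moe_expert kernels expect; here the three
// blocks are simply appended back to back.
static void repack_expert_weights(const std::vector<uint8_t>& wei,
                                  const std::vector<uint8_t>& scale,
                                  const std::vector<uint8_t>& zp,
                                  std::vector<uint8_t>& packed) {
    packed.insert(packed.end(), wei.begin(), wei.end());
    packed.insert(packed.end(), scale.begin(), scale.end());
    packed.insert(packed.end(), zp.begin(), zp.end());
}

// Conceptual CreateMOEOp flow: repack all experts on the host at
// compile_model time, upload once (to USM device memory in the real plugin),
// and hand the single memory handle to the primitive with its config.
static MoeExpertPrim create_moe_expert(const MoeConfig& cfg,
                                       const std::vector<std::vector<uint8_t>>& wei,
                                       const std::vector<std::vector<uint8_t>>& scale,
                                       const std::vector<std::vector<uint8_t>>& zp) {
    std::vector<uint8_t> packed;
    for (size_t e = 0; e < cfg.num_experts; ++e)
        repack_expert_weights(wei[e], scale[e], zp[e], packed);
    return MoeExpertPrim{cfg, std::make_shared<DeviceMemory>(DeviceMemory{std::move(packed)})};
}
```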

A reviewer (Contributor) commented:
LGTM op implementation.

@rkazants self-requested a review July 14, 2025 10:42
Labels
  • category: docs (OpenVINO documentation)
  • category: GPU (OpenVINO GPU plugin)
  • category: IE Tests (OpenVINO Test: plugins and common)
  • category: transformations (OpenVINO Runtime library - Transformations)