
[megatron, model] feat: qwen3.5 example #5381

Merged
wuxibin89 merged 14 commits into verl-project:main from ISEEKYAN:mcore_qwen35
Mar 13, 2026

Conversation

@ISEEKYAN (Collaborator) commented Feb 24, 2026:

What does this PR do?

Thanks to @LiuXTao's great work on ISEEKYAN/mbridge#83, mbridge now supports Qwen3.5.

This PR runs Qwen3.5 SFT on verl, building on that mbridge support for Qwen3.5.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

See examples/sft/gsm8k/run_qwen3_5_megatron.sh
and examples/grpo_trainer/run_qwen3_5-35b-megatron.sh

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@gemini-code-assist bot (Contributor) left a comment:


Code Review

This pull request adds support for Qwen3.5 SFT with Megatron. The changes are mostly workarounds and fixes to support the Qwen3.5 architecture, particularly its Gated Delta Net (GDN) and chat template requirements. The changes look reasonable and well-commented, improving compatibility and robustness. I have one major concern about catching a broad Exception which could hide bugs.

        return_tensors="pt",
        **apply_chat_template_kwargs,
    )
except (jinja2.exceptions.TemplateError, Exception) as e:

Severity: high

Catching a generic Exception is risky as it can suppress unexpected errors, making debugging difficult. It's better to catch more specific exceptions. Since jinja2.exceptions.TemplateError is a subclass of Exception, the tuple (jinja2.exceptions.TemplateError, Exception) is redundant and equivalent to except Exception:. Please replace Exception with the specific exception type(s) that are expected to contain the 'No user query' message. If the exact type is unknown, consider catching a narrower set of exceptions like ValueError or TypeError which are common for such issues.
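
For illustration, a minimal sketch of the narrower handling suggested above (the surrounding names come from the diff; the ValueError branch is an assumption about where the "No user query" message surfaces):

import jinja2

try:
    inputs = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        **apply_chat_template_kwargs,
    )
except (jinja2.exceptions.TemplateError, ValueError) as e:
    # Only swallow the known "No user query" case; re-raise anything else
    # so that unexpected bugs are not hidden.
    if "No user query" not in str(e):
        raise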

@wuyaoxuehun commented:

@ISEEKYAN can this PR also support RL?

@ISEEKYAN (Collaborator, Author) replied:

@ISEEKYAN can this PR also support RL?

Just updated a script with RL support. But it is not easy to prepare the right vllm dependency right now 🥲

@ISEEKYAN ISEEKYAN changed the title [megatron, model] feat: qwen3.5 megatron example of SFT [megatron, model] feat: qwen3.5 example Feb 26, 2026
@wuyaoxuehun commented:

@ISEEKYAN can this PR also support RL?

Just updated a script with RL support. But it is not easy to prepare the right vllm dependency right now 🥲

Many thanks.

"I succeeded in running this script with the main branch of vllm on 2026-02-25, yet there are still some minor issues: the vllm Qwen3.5 initialization needs to be fixed."

So what issue is there with the vllm Qwen3.5 initialization?

I see in the vllm docs that vllm can indeed serve Qwen3.5 (https://docs.vllm.ai/projects/recipes/en/latest/Qwen/Qwen3.5.html).

@khazic (Contributor) commented Feb 27, 2026:

Successfully ran Qwen3.5 SFT (verl megatron example) with the following setup:

(1) mbridge: install from source for qwen3_5 support —
pip install git+https://github.com/ISEEKYAN/mbridge.git
(PyPI 0.15.1 does not register qwen3_5; use git to get qwen3_5.)

(2) megatron-core == 0.16.0 — required for attention_output_gate and other GDN options.

(3) verl patch in verl/models/mcore/patch.py: applies the gate-slicing fix when
num_query_groups < tp_size (same as Megatron-LM PR #3529), plus the existing
mbridge compatibility patch. Without the gate patch, training fails with a
shape mismatch in _apply_output_gate.

Key library versions used:

  • mbridge: from git (reports 0.15.1)
  • megatron-core: 0.16.0
  • transformers: 5.2.0
  • torch: 2.9.0+cu129
  • flash_attn: 2.8.1
  • flash-linear-attention: 0.4.1

# Qwen3.5 uses Gated Delta Net (GDN) linear attention which currently does
# NOT support packed sequences (THD format) in Megatron-LM. Therefore:
# - actor.megatron.use_remove_padding=False (forces bshd compute format)
# - model.use_remove_padding=True (keeps NestedTensor in data pipeline)
A Collaborator commented on the diff:

For the new model engine, I think we always use NestedTensor regardless of model.use_remove_padding?

A reply:

I totally agree; this feature is definitely a must-have. From what I've seen in pure-text scenarios, given the same number of GPUs, the per-step latency is about 2 to 3 times that of LLaMA-Factory.
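
For intuition, the two compute formats contrasted in the comment block above, as a toy sketch (shapes are illustrative only):

import torch

# bshd: [batch, seq, heads, head_dim] with explicit padding -- the format
# GDN currently requires.
bshd = torch.zeros(2, 8, 4, 16)        # batch=2, padded seq=8

# thd (packed): [total_tokens, heads, head_dim], samples concatenated with
# cu_seqlens marking boundaries -- currently unsupported by GDN in Megatron-LM.
thd = torch.zeros(13, 4, 16)           # 13 tokens total
cu_seqlens = torch.tensor([0, 5, 13])  # two samples of lengths 5 and 8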

# Try the fast path first: direct unbind works for some NestedTensor
# layouts where the batch dim is not entangled with the ragged dim.
try:
    tensors = nt.unbind(dim=0)
A Collaborator commented on the diff:

In which case does NestedTensor unbind fail? I didn't expect that unbind may fail.

@ISEEKYAN (Collaborator, Author) replied:

3D jagged tensors (e.g., MRoPE position_ids).
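
For illustration, a sketch of the fast-path/fallback pattern this diff describes, as a standalone helper (assumed semantics, not verl's actual implementation):

import torch

def nested_unbind(nt: torch.Tensor) -> list[torch.Tensor]:
    try:
        # Fast path: direct unbind along the batch dim.
        return list(nt.unbind(dim=0))
    except RuntimeError:
        # Fallback for layouts where unbind fails (e.g., 3D jagged tensors
        # such as MRoPE position_ids): select each sample individually.
        return [nt[i] for i in range(nt.size(0))]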

@FlyingDutchman26 commented:

It seems that the vllm nightly version which supports Qwen3.5 requires transformers==4.57.6, which conflicts with the latest transformers version 5.2.0.

@ISEEKYAN (Collaborator, Author) commented Mar 3, 2026:

It seems that the vllm nightly version which supports Qwen3.5 requires transformers==4.57.6, which conflicts with the latest transformers version 5.2.0.

You could try with sglang.

@zyfzjsc988 (Contributor) commented:

It seems that the vllm nightly version which supports Qwen3.5 requires transformers==4.57.6, which conflicts with the latest transformers version 5.2.0.

Try installing vllm==0.15.0 first, then transformers==5.2.0.

@zyfzjsc988 (Contributor) commented:

Successfully ran Qwen3.5 MoE GRPO (this PR example) with the following setup:
vllm: 0.15.0
mbridge: git main branch (Commit 3dd5af9)
megatron-core: 0.16.0
transformers: 5.2.0
torch: 2.9.0+cu129
flash_attn: 2.8.1
flash-linear-attention: 0.4.1

However, when I ran Qwen3.5-27B, an error was raised:

  File "/root/verl/verl/workers/megatron_workers.py", line 887, in compute_log_prob
    output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 105, in f
    return self.log(decorated_function, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 118, in log
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 254, in compute_log_prob
    output = self.forward_backward_batch(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 716, in forward_backward_batch
    losses_reduced = forward_backward_func(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/pipeline_parallel/schedules.py", line 636, in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
                                ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 666, in forward_step
    output = forward_fn(
             ^^^^^^^^^^^
  File "/root/verl/verl/models/mcore/model_forward.py", line 141, in model_forward
    output_orig = model(
                  ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/module.py", line 489, in forward
    outputs = self.module(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/mbridge/models/qwen3_5/model.py", line 367, in forward
    output = self.language_model(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/models/gpt/gpt_model.py", line 504, in forward
    preproc_output = self._preprocess(
                     ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/models/gpt/gpt_model.py", line 388, in _preprocess
    rotary_pos_emb = self.rotary_pos_emb(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/mbridge/models/qwen3_vl/rope_utils.py", line 130, in forward
    seq_expanded = seq[:, :, None, :].float()
                   ~~~^^^^^^^^^^^^^^^
IndexError: too many indices for tensor of dimension 2

@ISEEKYAN does this PR support Qwen3.5-27B? Or is there any example for Qwen3.5-27B?

@ISEEKYAN (Collaborator, Author) commented Mar 4, 2026:

@ISEEKYAN does this PR support Qwen3.5-27B? Or is there any example for Qwen3.5-27B?

@zyfzjsc988 I only ran experiments on the MoE version; let me check whether there is any bug with the dense version. cc @LiuXTao

@FlyingDutchman26 commented Mar 4, 2026:

Successfully ran Qwen3.5 MoE GRPO (this PR example) with the following setup: vllm: 0.15.0 mbridge: git main branch (Commit 3dd5af9) megatron-core: 0.16.0 transformers: 5.2.0 torch: 2.9.0+cu129 flash_attn: 2.8.1 flash-linear-attention: 0.4.1

@zyfzjsc988 Hi, thanks for sharing your setup and experience with the Qwen3.5 MoE GRPO example!

I tried to reproduce your environment and also installed transformer-engine, but I'm still encountering an error:

ValueError: Model architectures ['Qwen3_5MoeForConditionalGeneration'] are not supported for now

I suspect that vllm==0.15.0 might not be the correct version to support this model architecture. Should we install the nightly version, or are there additional steps needed to make it work?

Any guidance would be greatly appreciated. Thanks in advance!

@zyfzjsc988 (Contributor) commented Mar 4, 2026:

@ISEEKYAN does this PR support Qwen3.5-27B? Or is there any example for Qwen3.5-27B?

@zyfzjsc988 I only ran experiments on the MoE version; let me check whether there is any bug with the dense version. cc @LiuXTao

Hi @ISEEKYAN, I fixed this bug and just ran Qwen3.5-0.8B (Qwen3.5-27B with TP4 also works) by adding a QWEN3_5_VL entry to SupportedVLM; you can also check whether anything else needs modifying for the dense model.

class SupportedVLM(Enum):
    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
    QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
    QWEN3_VL = "Qwen3VLForConditionalGeneration"
    QWEN3_5_MOE_VL = "Qwen3_5MoeForConditionalGeneration"
    QWEN3_5_VL = "Qwen3_5ForConditionalGeneration"

But Qwen3.5-27B still has bugs when TP=8, and I am not sure whether mbridge needs modification. For example:

  File "/root/verl/verl/workers/megatron_workers.py", line 887, in compute_log_prob
    output, entropys, layers_topk_idx = self.actor.compute_log_prob(data=data, calculate_entropy=not is_lora)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 105, in f
    return self.log(decorated_function, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/utils/profiler/performance.py", line 118, in log
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 254, in compute_log_prob
    output = self.forward_backward_batch(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 716, in forward_backward_batch
    losses_reduced = forward_backward_func(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/pipeline_parallel/schedules.py", line 636, in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
                                ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/pipeline_parallel/schedules.py", line 423, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/verl/verl/workers/actor/megatron_actor.py", line 666, in forward_step
    output = forward_fn(
             ^^^^^^^^^^^
  File "/root/verl/verl/models/mcore/model_forward.py", line 141, in model_forward
    output_orig = model(
                  ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/distributed/data_parallel_base.py", line 22, in forward
    return self.module(*inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/module.py", line 489, in forward
    outputs = self.module(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/mbridge/models/qwen3_5/model.py", line 367, in forward
    output = self.language_model(
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/models/gpt/gpt_model.py", line 525, in forward
    hidden_states = self.decoder(
                    ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/transformer_block.py", line 619, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/module.py", line 352, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/transformer_block.py", line 765, in forward
    hidden_states, context = layer(
                             ^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/transformer_layer.py", line 1217, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/module.py", line 352, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/transformer_layer.py", line 513, in forward
    hidden_states, context = self._forward_attention(*args, **kwargs)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/transformer_layer.py", line 597, in _forward_attention
    attention_output_with_bias = self.self_attention(
                                 ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/mbridge/models/qwen3_5/attention.py", line 360, in forward
    core_attn_out = self._apply_output_gate(core_attn_out, gate)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/megatron/core/transformer/attention.py", line 1221, in _apply_output_gate
    gate = gate.view(*x.shape)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[776, 1, 768]' is invalid for input of size 1191936

@zyfzjsc988 (Contributor) commented:

Successfully ran Qwen3.5 MoE GRPO (this PR example) with the following setup: vllm: 0.15.0 mbridge: git main branch (Commit 3dd5af9) megatron-core: 0.16.0 transformers: 5.2.0 torch: 2.9.0+cu129 flash_attn: 2.8.1 flash-linear-attention: 0.4.1

@zyfzjsc988 Hi, thanks for sharing your setup and experience with the Qwen3.5 MoE GRPO example!

I tried to reproduce your environment and also installed transformer-engine, but I'm still encountering an error:

ValueError: Model architectures ['Qwen3_5MoeForConditionalGeneration'] are not supported for now

I suspect that vllm==0.15.0 might not be the correct version to support this model architecture. Should we install the nightly version, or are there additional steps needed to make it work?

Any guidance would be greatly appreciated. Thanks in advance!

Please try installing vllm v0.16.1rc0.

@FlyingDutchman26 commented Mar 4, 2026:

@zyfzjsc988 @ISEEKYAN Thank you so much for your previous help. I've successfully run the GRPO training script.

However, I encountered an issue with transformers==5.2.0:

  File ".../lib/python3.12/site-packages/transformers/modeling_rope_utils.py", line 651, in convert_rope_params_to_dict
    ignore_keys_at_rope_validation = ignore_keys_at_rope_validation | {"partial_rotary_factor"}
                                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for |: 'list' and 'set'

As a temporary workaround, I modified the transformers source code locally:

if ignore_keys_at_rope_validation is None:
    ignore_keys_at_rope_validation = set()
elif not isinstance(ignore_keys_at_rope_validation, set):
    ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)

After this change, the training ran successfully.

I was wondering if you have encountered the same issue, or if there's a more proper fix (e.g., updating to a newer transformers version)?

Thanks again for your guidance!
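
For context, the TypeError above reduces to Python's unsupported list | set operation; a minimal repro of the bug and the spirit of the workaround:

# `|` is defined between two sets, but not between a list and a set.
ignore = ["partial_rotary_factor"]                    # a list coming from config
try:
    ignore = ignore | {"partial_rotary_factor"}       # raises TypeError
except TypeError:
    ignore = set(ignore) | {"partial_rotary_factor"}  # cast first, as above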

@baobaohanhan21 commented:

@zyfzjsc988 @ISEEKYAN Thank you so much for your previous help. I've successfully run the GRPO training script.

However, I encountered an issue with transformers==5.2.0:

  File ".../lib/python3.12/site-packages/transformers/modeling_rope_utils.py", line 651, in convert_rope_params_to_dict
    ignore_keys_at_rope_validation = ignore_keys_at_rope_validation | {"partial_rotary_factor"}
                                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~
TypeError: unsupported operand type(s) for |: 'list' and 'set'

As a temporary workaround, I modified the transformers source code locally:

if ignore_keys_at_rope_validation is None:
    ignore_keys_at_rope_validation = set()
elif not isinstance(ignore_keys_at_rope_validation, set):
    ignore_keys_at_rope_validation = set(ignore_keys_at_rope_validation)

After this change, the training ran successfully.

I was wondering if you have encountered the same issue, or if there's a more proper fix (e.g., updating to a newer transformers version)?

Thanks again for your guidance!

I've also successfully run the GRPO training script; converting to set(ignore_keys_at_rope_validation) is a proper fix for now. But I encountered CPU OOM during save_checkpoints.

@Code4Graph commented Mar 10, 2026:

Does anyone see the transformer_engine issue when running Qwen3.5-9B with GRPO? Which transformer_engine version works when upgrading to vllm==0.17.0 with torch 2.10.0?

File "/usr/local/lib/python3.12/dist-packages/megatron/core/tensor_parallel/random.py", line 30, in <module>
    from transformer_engine.pytorch.distributed import activation_recompute_forward
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/__init__.py", line 18, in <module>
    load_framework_extension("torch")
  File "/usr/local/lib/python3.12/dist-packages/transformer_engine/common/__init__.py", line 190, in load_framework_extension
    solib = importlib.util.module_from_spec(spec)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ImportError: /usr/local/lib/python3.12/dist-packages/transformer_engine/transformer_engine_torch.cpython-312-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda29c10_cuda_check_implementationEiPKcS2_ib

@shinesun07 commented Mar 11, 2026:

Does anyone see this issue when running Qwen3.5-35B-A3B with GRPO?
transformers==5.3.0

File "/usr/local/lib/python3.10/dist-packages/verl/experimental/agent_loop/agent_loop.py", line 737, in _compute_position_ids
    vision_position_ids, _ = self.processor.get_rope_index(
TypeError: Qwen3VLModel.get_rope_index() missing 1 required positional argument: 'mm_token_type_ids'

@mamazi0131 commented:

Successfully ran Qwen3.5 MoE GRPO (this PR example) with the following setup: vllm: 0.15.0; mbridge: git main branch (Commit 3dd5af9); megatron-core: 0.16.0; transformers: 5.2.0; torch: 2.9.0+cu129; flash_attn: 2.8.1; flash-linear-attention: 0.4.1

@zyfzjsc988 Hi, thanks for sharing your setup and experience with the Qwen3.5 MoE GRPO example!
I tried to reproduce your environment and also installed transformer-engine, but I'm still encountering an error:

ValueError: Model architectures ['Qwen3_5MoeForConditionalGeneration'] are not supported for now

I suspect that vllm==0.15.0 might not be the correct version to support this model architecture. Should we install the nightly version, or are there additional steps needed to make it work?
Thanks in advance for your guidance!

Please try installing vllm v0.16.1rc0.

Thanks for your advice. Could you please provide the full experimental version? It appears that vLLM 0.16.1rc0 was compiled with Torch 2.10.

@mamazi0131 commented:

@Code4Graph

I’m experiencing the same problem on an environment running CUDA 12.9, PyTorch 2.10, Transformer Engine 2.10.0+769ed77, and vLLM 0.16.1rc0.

@Code4Graph commented Mar 11, 2026:

Code4Graph

I’m experiencing the same problem on an environment running CUDA 12.9, PyTorch 2.10, Transformer Engine 2.10.0+769ed77, and vLLM 0.16.1rc0.

torch 2.10.0, cu128, transformer-engine 2.12.0, vllm 0.17.0 works for me now (GRPO without LoRA).

@Code4Graph commented Mar 11, 2026:

Does anyone see this issue when running Qwen3.5-35B-A3B with GRPO? transformers==5.3.0

File "/usr/local/lib/python3.10/dist-packages/verl/experimental/agent_loop/agent_loop.py", line 737, in _compute_position_ids
    vision_position_ids, _ = self.processor.get_rope_index(
TypeError: Qwen3VLModel.get_rope_index() missing 1 required positional argument: 'mm_token_type_ids'

I solved it with the following update and tested it. Please correct me if I made any mistakes in the update.

        # Construct mm_token_type_ids: 0=text, 1=image, 2=video
        mm_token_type_ids = torch.zeros_like(input_ids, dtype=torch.int)
        if hasattr(self.processor, "config"):
            config = self.processor.config
            if hasattr(config, "image_token_id") and config.image_token_id is not None:
                mm_token_type_ids[input_ids == config.image_token_id] = 1
            if hasattr(config, "video_token_id") and config.video_token_id is not None:
                mm_token_type_ids[input_ids == config.video_token_id] = 2
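
For reference, a hedged sketch of the corresponding call site (the keyword name follows the TypeError above; the other arguments shown are assumptions):

# Assumed call after constructing mm_token_type_ids; other arguments elided.
vision_position_ids, _ = self.processor.get_rope_index(
    input_ids,
    image_grid_thw=image_grid_thw,
    video_grid_thw=video_grid_thw,
    mm_token_type_ids=mm_token_type_ids,
)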

@baobaohanhan21 commented:

How do I disable thinking in Qwen3.5 RL?

@FlyingDutchman26 commented:

Thanks for your advice. Could you please provide the full experimental version? It appears that vLLM 0.16.1rc0 was compiled with Torch 2.10.

Maybe you could try:

pip download vllm --pre --index-url https://wheels.vllm.ai/nightly --only-binary=:all: --dest {your_dest} --no-deps

pip install {your_dest}

@wuxibin89 (Collaborator) commented Mar 11, 2026:

The verl release image has been upgraded to vllm==0.17.0 and sglang==0.5.9 (#5542) for Qwen3-VL/Qwen3 series models.

@wuxibin89 (Collaborator) commented Mar 13, 2026:

Test examples/grpo_trainer/run_qwen3_5-35b-megatron.sh:

  • image: verlai/verl:vllm017.latest or verlai/verl:sgl059.latest
  • pip install transformers==5.3.0 flash-linear-attention
  • model: Qwen3.5-35B-A3B
  • dataset: geo3k
  • megatron: tp=2,pp=1,ep=8

@wuxibin89 merged commit ef072ac into verl-project:main on Mar 13, 2026
48 of 55 checks passed