Commit c63ab78

[model] add qwen3 support (#276)
1 parent cf1d7ef commit c63ab78

4 files changed

Lines changed: 15 additions & 3 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://a
 ## Features
 
 - Supported models
-  - Llama3/Qwen2/Qwen2.5 language models
+  - Llama3/Qwen2/Qwen2.5/Qwen3 language models
   - Qwen2/Qwen2.5-VL vision language models
   - DeepSeek-R1 distill models
 

examples/qwen3_4b_math_grpo.sh

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+set -x
+
+export PYTHONUNBUFFERED=1
+
+MODEL_PATH=Qwen/Qwen3-4B  # replace it with your local file path
+
+python3 -m verl.trainer.main \
+    config=examples/config.yaml \
+    data.max_response_length=4096 \
+    worker.actor.model.model_path=${MODEL_PATH}
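
The new script configures the run entirely through dotted `key=value` arguments layered on top of `examples/config.yaml`. As a rough illustration of how such overrides could be folded into a nested config, here is a minimal pure-Python sketch. `apply_overrides` is a hypothetical helper that is not part of verl, the starting value `2048` is a made-up placeholder, and the real trainer may well delegate this to a config library.

```python
from typing import Any


def apply_overrides(config: dict[str, Any], overrides: list[str]) -> dict[str, Any]:
    """Fold dotted key=value strings into a nested dict (hypothetical helper)."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections on demand.
            node = node.setdefault(name, {})
        # Only plain integers are converted; everything else stays a string.
        node[leaf] = int(raw_value) if raw_value.isdigit() else raw_value
    return config


config = apply_overrides(
    {"data": {"max_response_length": 2048}},  # placeholder, not the real config.yaml
    [
        "data.max_response_length=4096",
        "worker.actor.model.model_path=Qwen/Qwen3-4B",
    ],
)
print(config)
```

The point of the sketch is only that later command-line values replace earlier file values key by key, which is why the script can raise `data.max_response_length` without editing the YAML file.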

verl/models/monkey_patch.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 
 
 def apply_ulysses_patch(model_type: str) -> None:
-    if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2"):
+    if model_type in ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe"):
         ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
     elif model_type in ("qwen2_vl", "qwen2_5_vl"):
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2

verl/utils/flops_counter.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from transformers.models.llama.configuration_llama import LlamaConfig
 
 
-VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
+VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
 
 
 def get_device_flops(unit: str = "T") -> float:
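
The two Python hunks extend different lists: `apply_ulysses_patch` gains both `"qwen3"` and `"qwen3_moe"`, while `VALID_MODLE_TYPE` gains only `"qwen3"`. A quick way to see which model types are covered by the attention patch but absent from the FLOPs counter's set is a plain set difference. The sketch below copies the post-change values from the hunks above; the names `ULYSSES_TEXT_MODEL_TYPES` and `ULYSSES_VL_MODEL_TYPES` are labels introduced here for illustration. How the FLOPs counter treats an unlisted type is not shown in this diff, so this only reports the mismatch.

```python
# Post-change values copied from the hunks in this commit.
# The two ULYSSES_* names are illustrative labels, not identifiers from the repo.
ULYSSES_TEXT_MODEL_TYPES = ("llama", "gemma", "gemma2", "mistral", "qwen2", "qwen3", "qwen3_moe")
ULYSSES_VL_MODEL_TYPES = ("qwen2_vl", "qwen2_5_vl")
VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}

# Every model type the attention patch accepts.
patched = set(ULYSSES_TEXT_MODEL_TYPES) | set(ULYSSES_VL_MODEL_TYPES)

# Types that are patched but missing from the FLOPs counter's set.
patched_but_not_counted = sorted(patched - VALID_MODLE_TYPE)
print(patched_but_not_counted)  # ['gemma', 'gemma2', 'mistral', 'qwen3_moe']
```
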
