Skip to content

Commit 0528ba1

Browse files
as12138Chendong98zheliuyu
authored
[NPU] feat: Support FSDP worker and vLLM Ascend (#332)
For developers, you can follow the docs: docs/ascend/ascend.rst This pr is committed for supporting Ascend NPU backend. Co-authored-by: Chendong98 [[email protected]](mailto:[email protected]) Co-authored-by: zheliuyu <[email protected]> Co-authored-by: celestialli [[email protected]](mailto:[email protected]) In this pr, we add the capability to determine the type of NPU device and we also add a new script for training on NPU. These are change lists: 1. pyproject.toml change verison of vllm 2. requirements-npu.txt requirements for NPU 3. verl/bert_padding.py Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py 4. verl/single_controller/ray/base.py 5. verl/third_party/vllm/vllm_spmd/dtensor_weight_loaders.py 6. verl/trainer/fsdp_sft_trainer.py 7. verl/utils/flops_counter.py 8. verl/utils/fsdp_utils.py 9. verl/workers/actor/dp_actor.py 10. verl/workers/critic/dp_critic.py 11. verl/workers/fsdp_workers.py 12. verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py 13. verl/workers/sharding_manager/fsdp_vllm.py 14. verl/utils/device.py get device type for different device 15. docs/ascend/ascend.md Here are our roadmap: **RoadMap** - [x] sft - [x] ppo - [x] grpo News [2025.03.31] Add result of SFT and GRPO. Qwen2-7B-Instruct was tested on 2*8 devices, and many params related to batch_size need to be reduced. So this result is only for reference. We will announce the reward results of the default params as soon as sleep mode is supported. [2025.03.03] Modify the adaptation method of Ray [2025.02.25] The PPO algorithm is supported for training on NPU with the FSDP backend. [2025.02.23] The SFT algorithm is supported for training on NPU with the FSDP backend. [2025.02.21] The GRPO algorithm is supported for training on NPU with the FSDP backend. Requirements We use this PR testing on Ascend NPU and GPU to ensure the same codes can run on different devices. The device information is 8 Atlas 800T A2 and 8 A100. Other software information is shown in the following table. | Software | Version | |:-------|-------:| | transformers | 4.47.1 | | accelerate | 1.3.0 | | torch_npu | 2.5.1.rc1| |CANN | 8.1.RC1 (Not Released)| About mean error Due to differences in hardware structure, we cannot guarantee that the loss of Ascend NPU is exactly the same as that of the GPU. According to our experience, the loss differences less than 2% is acceptable. If the loss difference is greater than 2%, we will try to fix it. The calculation formula is as follows. ![loss_comparison](https://github.com/user-attachments/assets/4f62f713-9240-4324-bf7d-3ae59fc85b05) N represents the number of training steps. For more information, please refer to [Calculation accuracy description](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html) --------- Co-authored-by: Chendong98 <[email protected]> Co-authored-by: zheliuyu <[email protected]>
1 parent a7b2e29 commit 0528ba1

30 files changed

+529
-109
lines changed

.github/workflows/e2e_ascend.yml

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ jobs:
2626
test:
2727
name: verl Ascend test (self-host)
2828
runs-on: [self-hosted, npu-0]
29-
timeout-minutes: 5 # Increase this timeout value as needed
29+
timeout-minutes: 30 # Increase this timeout value as needed
3030
container:
31-
image: quay.io/ascend/cann:8.0.0-910b-ubuntu22.04-py3.10
31+
image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10
3232
volumes:
3333
- /usr/local/dcmi:/usr/local/dcmi
3434
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi
@@ -42,13 +42,56 @@ jobs:
4242
--device /dev/hisi_hdc
4343
--privileged
4444
--network "host"
45+
--shm-size 2g
46+
env:
47+
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
48+
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
49+
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
50+
HF_ENDPOINT: "https://hf-mirror.com"
51+
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
4552
steps:
4653
- name: Check npu and CANN info
4754
run: |
4855
cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info
4956
npu-smi info
5057
- name: Checkout volcengine/verl repo
5158
uses: actions/checkout@v4
52-
- name: Run test
59+
- name: Install torch
5360
run: |
54-
lscpu
61+
pip install torch==2.5.1+cpu --index-url https://download.pytorch.org/whl/cpu
62+
pip install torch-npu==2.5.1
63+
pip install /usr/local/Ascend/ascend-toolkit/latest/lib64/te-0.4.0-py3-none-any.whl
64+
- name: Install vllm
65+
run: |
66+
apt-get update && apt-get install -y git
67+
git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm.git vllm-npu
68+
cd vllm-npu
69+
pip install -r requirements-build.txt
70+
VLLM_TARGET_DEVICE=empty pip install -e . --extra-index https://download.pytorch.org/whl/cpu/
71+
- name: Install vllm-ascend
72+
run: |
73+
pip list
74+
pip show torch
75+
git clone -b v0.7.3 --depth 1 https://github.com/vllm-project/vllm-ascend.git
76+
cd vllm-ascend
77+
export COMPILE_CUSTOM_KERNELS=1
78+
python setup.py install
79+
- name: Install the current repository
80+
run: |
81+
pip3 install hf_transfer peft
82+
pip3 install -r requirements-npu.txt
83+
pip install -e .
84+
- name: Prepare gsm8k dataset
85+
run: |
86+
ray stop --force
87+
python3 examples/data_preprocess/gsm8k.py
88+
- name: Running gsm8k e2e training tests with LoRA on ASCEND NPU
89+
run: |
90+
ray stop --force
91+
bash tests/e2e/sft/run_sft.sh
92+
rm -rf $HOME/ckpts
93+
- name: Running gsm8k e2e training tests with GRPO on ASCEND NPU
94+
run: |
95+
ray stop --force
96+
bash tests/npu/run_qwen2_5_05b_grpo.sh
97+
rm -rf $HOME/ckpts

docs/ascend/ascend_vllm073.rst

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
verl x Ascend
2+
========
3+
4+
我们在 verl 上增加对华为昇腾设备的支持。
5+
6+
硬件支持
7+
=======
8+
9+
* Atlas 800T A2
10+
11+
* Atlas 200T A2 Box16
12+
13+
安装
14+
=======
15+
16+
环境准备
17+
------
18+
19+
+--------------+----------+
20+
| 软件 | 版本 |
21+
+-----------+-------------+
22+
| Python | == 3.10 |
23+
| torch | == 2.5.1 |
24+
| torch_npu | == 2.5.1rc1 |
25+
| CANN | == 8.1.RC1 |
26+
+-----------+-------------+
27+
28+
1. 使用 vLLM,需遵循 vllm-ascend 的安装教程 <https://vllm-ascend.readthedocs.io/en/v0.7.3/installation.html>。
29+
2. 为了能够在 ASCEND NPU 上正常使能 flash_attention_2, transformers 版本需要大于等于 4.52.0。
30+
3. 目前支持 SFT 与 LLM 模型的 GRPO 训练,VLM模型的 GRPO 训练因为 vllm-ascend 的问题将会在后续支持,涉及到的issue为:
31+
32+
https://github.com/vllm-project/vllm-ascend/issues/809
33+
34+
https://github.com/vllm-project/vllm-ascend/issues/825
35+
36+
源码安装
37+
------
38+
39+
.. code-block::
40+
git clone https://github.com/volcengine/verl.git
41+
cd verl
42+
pip install -r requirements-npu.txt
43+
pip install -e .
44+
45+
vLLM
46+
------
47+
48+
为了保证能够在 verl 上正常使用 vLLM,需要安装 vLLM Ascend 插件(`vllm-ascend`)。关于在华为昇腾上支持的 vLLM 版本以及和 vLLM Ascend 的配套关系请参考`安装教程 <https://vllm-ascend.readthedocs.io/en/v0.7.3/installation.html>`_。
49+
50+
其他第三方库说明
51+
------
52+
53+
+--------------+--------+
54+
| 软件 | 说明 |
55+
+--------------+--------+
56+
| flash_attn | 不支持 |
57+
+--------------+--------+
58+
| liger-kernel | 不支持 |
59+
+--------------+--------+
60+
61+
精度对比
62+
------
63+
64+
根据经验,对于SFT等微调算法,我们期望在相同配置下,在华为昇腾设备上的 Loss 与英伟达 GPU 的 Loss 平均绝对误差小于等于 2%,具体计算方式如下:
65+
66+
.. image:: https://github.com/eric-haibin-lin/verl-community/tree/main/docs/loss_comparison.png
67+
:alt: loss_comparison
68+
69+
其中,N 表示训练的步数。更多信息请参考[精度计算说明](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/LMaccuracy_0001.html)。
70+
71+
根据经验,对于GRPO等强化学习算法,我们期望在相同配置下,在华为昇腾设备上的 reward 与英伟达 GPU 的 reward 平均绝对误差小于等于 4%,具体计算参考 Loss 计算。
72+
73+
进展
74+
------
75+
76+
+--------+--------+
77+
| 算法 | 进展 |
78+
+--------+--------+
79+
| SFT | 已支持 |
80+
+--------+--------+
81+
| GRPO | 已支持 |
82+
+--------+--------+

recipe/dapo/main_dapo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import ray
2222

2323
from .dapo_ray_trainer import RayDAPOTrainer
24+
from verl.utils.device import is_cuda_available
2425

2526

2627
def get_custom_reward_fn(config):

recipe/sppo/main_sppo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from verl.trainer.ppo.reward import load_reward_manager
2626

2727
from .sppo_ray_trainer import RaySPPOTrainer
28+
from verl.utils.device import is_cuda_available
2829

2930

3031
@hydra.main(config_path="config", config_name="sppo_trainer", version_base=None)
@@ -140,6 +141,7 @@ def run(self, config):
140141
ray_worker_group_cls=ray_worker_group_cls,
141142
reward_fn=reward_fn,
142143
val_reward_fn=val_reward_fn,
144+
device_name="cuda" if is_cuda_available else "npu",
143145
)
144146
trainer.init_workers()
145147
trainer.fit()

recipe/sppo/sppo_ray_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
val_dataset: Optional[Dataset] = None,
8787
collate_fn=None,
8888
train_sampler: Optional[Sampler] = None,
89+
device_name="cuda",
8990
):
9091
self.tokenizer = tokenizer
9192
self.processor = processor
@@ -105,6 +106,7 @@ def __init__(
105106
self.use_rm = Role.RewardModel in role_worker_mapping
106107
self.ray_worker_group_cls = ray_worker_group_cls
107108
self.validation_generations_logger = ValidationGenerationsLogger()
109+
self.device_name = device_name
108110

109111
# define in-reward KL control
110112
# kl loss control currently not suppoorted

requirements-npu.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# requirements.txt records the full set of dependencies for development
2+
accelerate
3+
codetiming
4+
datasets
5+
dill
6+
hydra-core
7+
numpy
8+
pandas
9+
peft
10+
pyarrow>=15.0.0
11+
pybind11
12+
pylatexenc
13+
ray
14+
tensordict<=0.6.2
15+
transformers>=4.52.0
16+
wandb
17+
mathruler
18+
torchdata
19+
einops
20+
qwen_vl_utils

tests/npu/run_qwen2_5_05b_grpo.sh

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
set -x
2+
3+
export VLLM_ATTENTION_BACKEND=XFORMERS
4+
5+
python3 -m verl.trainer.main_ppo \
6+
algorithm.adv_estimator=grpo \
7+
data.train_files=$HOME/data/gsm8k/train.parquet \
8+
data.val_files=$HOME/data/gsm8k/test.parquet \
9+
data.train_batch_size=128 \
10+
data.max_prompt_length=512 \
11+
data.max_response_length=128 \
12+
data.filter_overlong_prompts=True \
13+
data.truncation='error' \
14+
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
15+
actor_rollout_ref.actor.optim.lr=1e-6 \
16+
actor_rollout_ref.model.use_remove_padding=False \
17+
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
18+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \
19+
actor_rollout_ref.actor.use_kl_loss=True \
20+
actor_rollout_ref.actor.kl_loss_coef=0.001 \
21+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
22+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
23+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
24+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
25+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=40 \
26+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
27+
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
28+
actor_rollout_ref.rollout.name=vllm \
29+
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
30+
actor_rollout_ref.rollout.n=5 \
31+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=40 \
32+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
33+
algorithm.kl_ctrl.kl_coef=0.001 \
34+
trainer.critic_warmup=0 \
35+
trainer.logger=['console'] \
36+
trainer.project_name='verl_grpo_example_gsm8k' \
37+
trainer.experiment_name='qwen2_7b_function_rm' \
38+
trainer.n_gpus_per_node=8 \
39+
trainer.nnodes=1 \
40+
trainer.save_freq=-1 \
41+
trainer.test_freq=5 \
42+
trainer.total_epochs=1 $@

tests/npu/run_qwen2_5_32b_grpo.sh

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
set -x
2+
3+
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
4+
# export VLLM_ATTENTION_BACKEND=XFORMERS
5+
6+
python3 -m verl.trainer.main_ppo \
7+
algorithm.adv_estimator=grpo \
8+
data.train_files=$HOME/data/gsm8k/train.parquet \
9+
data.val_files=$HOME/data/gsm8k/test.parquet \
10+
data.train_batch_size=1024 \
11+
data.max_prompt_length=1024 \
12+
data.max_response_length=1024 \
13+
data.filter_overlong_prompts=True \
14+
data.truncation='error' \
15+
actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
16+
actor_rollout_ref.actor.optim.lr=1e-6\
17+
actor_rollout_ref.model.use_remove_padding=False \
18+
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
19+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
20+
actor_rollout_ref.actor.use_kl_loss=True \
21+
actor_rollout_ref.actor.entropy_coeff=0 \
22+
actor_rollout_ref.actor.kl_loss_coef=0.001 \
23+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
24+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
25+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
26+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
27+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
28+
actor_rollout_ref.rollout.tensor_model_parallel_size=8 \
29+
actor_rollout_ref.rollout.name=vllm \
30+
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
31+
actor_rollout_ref.rollout.n=5 \
32+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
33+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
34+
algorithm.use_kl_in_reward=False \
35+
trainer.critic_warmup=0 \
36+
trainer.logger=['console'] \
37+
trainer.project_name='verl_grpo_example_gsm8k' \
38+
trainer.experiment_name='qwen2_5_32b_function_rm' \
39+
trainer.n_gpus_per_node=16 \
40+
trainer.nnodes=2 \
41+
trainer.save_freq=-1 \
42+
trainer.test_freq=10 \
43+
trainer.total_epochs=15 $@

tests/npu/run_qwen2_5_7b_grpo.sh

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
set -x
2+
3+
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
4+
# export VLLM_ATTENTION_BACKEND=XFORMERS
5+
6+
python3 -m verl.trainer.main_ppo \
7+
algorithm.adv_estimator=grpo \
8+
data.train_files=$HOME/data/gsm8k/train.parquet \
9+
data.val_files=$HOME/data/gsm8k/test.parquet \
10+
data.train_batch_size=1024 \
11+
data.max_prompt_length=1024 \
12+
data.max_response_length=1024 \
13+
data.filter_overlong_prompts=True \
14+
data.truncation='error' \
15+
actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \
16+
actor_rollout_ref.actor.optim.lr=5e-8 \
17+
actor_rollout_ref.model.use_remove_padding=False \
18+
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
19+
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
20+
actor_rollout_ref.actor.use_kl_loss=True \
21+
actor_rollout_ref.actor.entropy_coeff=0 \
22+
actor_rollout_ref.actor.kl_loss_coef=0.001 \
23+
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
24+
actor_rollout_ref.model.enable_gradient_checkpointing=True \
25+
actor_rollout_ref.actor.fsdp_config.param_offload=False \
26+
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
27+
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
28+
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
29+
actor_rollout_ref.rollout.name=vllm \
30+
actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
31+
actor_rollout_ref.rollout.n=5 \
32+
actor_rollout_ref.rollout.enable_chunked_prefill=False \
33+
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
34+
actor_rollout_ref.ref.fsdp_config.param_offload=True \
35+
algorithm.use_kl_in_reward=False \
36+
trainer.critic_warmup=0 \
37+
trainer.logger=['console'] \
38+
trainer.project_name='verl_grpo_example_gsm8k' \
39+
trainer.experiment_name='qwen2_5_7b_function_rm' \
40+
trainer.n_gpus_per_node=16 \
41+
trainer.nnodes=1 \
42+
trainer.save_freq=-1 \
43+
trainer.test_freq=5 \
44+
trainer.total_epochs=5 $@

verl/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414

1515
import logging
1616
import os
17+
import pkg_resources
1718

19+
from pkg_resources import DistributionNotFound
20+
from packaging.version import parse as parse_version
1821
from .protocol import DataProto
1922
from .utils.logging_utils import set_basic_config
23+
from .utils.device import is_npu_available
2024

2125
version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))
2226

@@ -38,3 +42,17 @@
3842
from modelscope.utils.hf_util import patch_hub
3943

4044
patch_hub()
45+
46+
if is_npu_available:
47+
package_name = 'transformers'
48+
required_version_spec = '4.51.0'
49+
try:
50+
installed_version = pkg_resources.get_distribution(package_name).version
51+
installed = parse_version(installed_version)
52+
required = parse_version(required_version_spec)
53+
54+
if not installed >= required:
55+
raise ValueError(f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, current version is {installed}.")
56+
except DistributionNotFound:
57+
raise ImportError(
58+
f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}")

verl/models/transformers/qwen2_vl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434

3535
try:
36-
from flash_attn import flash_attn_func, flash_attn_varlen_func
36+
from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func
3737

3838
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
3939
except ImportError:

0 commit comments

Comments
 (0)