Commit edd34aa

[RL] RL support varlen flash mask (#10490)
* update rl
* remove the optimizer timer
* rename rollout_continue_batching_batch_size to rollout_max_num_seqs and quant_type to rollout_quant_type
* update global_mini_batch_size
* fix entropy_coeff dtype error
* add einops requirements
* update documentation
* use remove padding so we do not need to pad_to_multiple_of the tp degree
* update rollout_max_num_seqs
* squeeze unsqueeze
* shuffle
* fix training bug
* fix missing pad
* add dataloader_shuffle args
1 parent d6182ff commit edd34aa

19 files changed: +630 -340 lines

llm/alignment/ppo/README.md (+21 -24)
@@ -36,26 +36,21 @@ python setup_cuda.py install
 
 ### Field descriptions
 
-- src (list(str)): the user's conversation turns; may contain markup such as [<search-res>]
-- tgt (list(str)): the system's replies for every turn except the last, arranged by conversation turn; may contain markup such as [<search>]; note: len(tgt)==len(src)-1
+- src (list(str)): the prompt input after processing with the chat_template
+- tgt (list(str)): the label content
 
 ### Data example
 
 ```json
 {
-    "src": [
-        "需要你帮我写几个有创意的广告语来打开市场。",
-        "目标用户是年轻人,追求时尚、个性和自我。"
-    ],
-    "tgt": [
-        "当然!我很乐意帮助你创作几个有创意的广告语来推广你的新洗发露。请告诉我一些关于你的产品的特点,目标受众以及你希望传达的核心信息,我会根据这些信息为你提供几个创意的广告语。"
-    ]
+    "src": ["<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within <answer> </answer> tags. i.e., <answer> (1) Zoey is a knight\n(2) ... </answer>.\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 3 inhabitants: Michael, Zoey, and Ethan. Michael was heard saying, \"Ethan is a knight if and only if Michael is a knight\". \"Zoey is a knight or Ethan is a knight,\" Zoey mentioned. Ethan asserted: \"Michael is a knave if and only if Zoey is a knave\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n<think>"],
+    "tgt": ["(1) Michael is a knight\n(2) Zoey is a knight\n(3) Ethan is a knight"]
 }
 ```
 
 
 ### PPO & GRPO data preparation
-
+We provide a version of the [KK dataset](https://hf-mirror.com/datasets/K-and-K/knights-and-knaves) preprocessed with the `chat template` of `Qwen/Qwen2.5-7B-Instruct-1M`.
 ```
 wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo-kk.tgz
 ```
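Note on the data layout: the `src`/`tgt` fields above are what the GRPO config later references via `prompt_key: "src"` and `response_key: "tgt"`. Below is a minimal, hedged sketch for sanity-checking records against that schema; the `ppo-kk/train.jsonl` path and the one-JSON-object-per-line layout are illustrative assumptions, since the commit does not show how the downloaded archive is organized.

```python
import json

# Hypothetical path: the exact layout inside ppo-kk.tgz is not shown in this commit.
DATA_PATH = "ppo-kk/train.jsonl"

def check_record(record: dict) -> None:
    """Verify one record matches the src/tgt schema described above."""
    assert isinstance(record["src"], list) and all(isinstance(s, str) for s in record["src"])
    assert isinstance(record["tgt"], list) and all(isinstance(t, str) for t in record["tgt"])

with open(DATA_PATH, "r", encoding="utf-8") as f:
    for line in f:
        if line.strip():
            check_record(json.loads(line))
```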
@@ -66,9 +61,6 @@ wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo
 
 The configuration files we use are placed in `llm/config/llama/ppo_argument.json` and `llm/config/llama/grpo_argument.json`, and detailed parameter descriptions are provided below:
 
-- `train_task_config`: training data config; see `config/task_ppo.json` for an example
-- `eval_task_config`: evaluation data config; see `config/task_ppo.json` for an example
-- `ptx_task_config`: auxiliary SFT data; see `config/task_sft.json` for an example; defaults to ""
 - `actor_model_name_or_path`: local model path for the actor model and reference model in PPO
 - `reward_model_name_or_path`: local model path for the reward model and critic model in PPO
 - `use_fusemt`: whether to accelerate generation with FuseMT; defaults to True
@@ -95,8 +87,7 @@ wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo
 - `critic_weight_decay`: weight decay applied to all layers of the critic model except all bias and LayerNorm weights. (`float`, optional, defaults to 0.0)
 - `max_prompt_len`: maximum generation length when generating samples; increasing max_length increases generation time and memory usage. Note:
 max_dec_len + max_prompt_len should be smaller than max_seq_len.
-- `per_device_prompt_batch_size`: batch size for PPO sample generation, the same as the micro batch size, i.e. global_batch_size = dp (data parallel) * sharding * micro batch size. Increasing batch_size increases generation time and memory usage
-- `per_device_train_batch_size`: training batch size; currently set to 1 for performance reasons, please avoid changing it
+- `per_device_train_batch_size`: training batch size
 - `per_device_eval_batch_size`: evaluation batch size.
 - `max_steps`: total number of training steps
 - `eval_steps`: number of steps between model evaluations
@@ -109,13 +100,8 @@ max_dec_len + max_prompt_len should be smaller than max_seq_len.
 - `fp16`: use float16 precision for model training and inference.
 - `bf16`: use bfloat16 precision for model training and inference.
 - `fp16_opt_level`: float16 training mode; `O2` means pure float16 training
-
-
-<!-- ### PPO training command
-
-```shell
-python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_ppo.py llm/config/llama/ppo_argument.json
-``` -->
+- `balance_batch`: whether to balance the number of tokens within a batch across data-parallel devices. If set to True, the system tries to even out the token distribution across the parallel devices; if set to False (default), no such balancing is performed.
+- `use_remove_padding`: whether to strip the padding from the input during training. Enabling it (set to True) raises the proportion of valid tokens processed per step and thus improves training efficiency; if set to False (default), the padding in the input is kept.
 
 ### GRPO training command
 ```shell
@@ -130,8 +116,19 @@ python reward_server.py
 ```shell
 export PYTHONPATH=your_PaddleNLP_path/:$PYTHONPATH
 export PYTHONPATH=your_PaddleNLP_path/llm:$PYTHONPATH
-python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_ppo.py ../../config/qwen/grpo_argument.json
-# python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_ppo.py ../../config/llama/grpo_argument.json
+
+export FLAGS_set_to_1d=False
+export NVIDIA_TF32_OVERRIDE=0
+export FLAGS_dataloader_use_file_descriptor=False
+export HF_DATASETS_DOWNLOAD_TIMEOUT=1
+export FLAGS_gemm_use_half_precision_compute_type=False
+export FLAGS_force_cublaslt_no_reduced_precision_reduction=True
+
+export FLAGS_mla_use_tensorcore=0
+export FLAGS_cascade_attention_max_partition_size=2048
+
+python -u -m paddle.distributed.launch --devices "0,1,2,3" run_ppo.py ../../config/qwen/grpo_argument.yaml
+# python -u -m paddle.distributed.launch --devices "0,1,2,3" run_ppo.py ../../config/llama/grpo_argument.yaml
 ```
 
 ### Online monitoring
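The `use_remove_padding` option documented above is what the commit title refers to: with padding removed, attention runs in variable-length ("varlen") mode, where the non-pad tokens of a batch are packed into one flat sequence and a `cu_seqlens` offset array marks each sample's boundaries instead of a padding mask. The sketch below illustrates only the packing idea in NumPy; it is not PaddleNLP's implementation, and the helper name is made up.

```python
import numpy as np

def remove_padding(input_ids: np.ndarray, attention_mask: np.ndarray):
    """Pack the non-pad tokens of a [batch, seq_len] batch into one flat sequence.

    Returns the packed token ids plus cu_seqlens, the cumulative sequence-length
    offsets that a varlen flash-attention kernel consumes instead of a pad mask.
    """
    seq_lens = attention_mask.sum(axis=-1)                   # valid tokens per sample
    packed = input_ids[attention_mask.astype(bool)]          # flat [total_tokens] array
    cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)])  # sample boundaries
    return packed, cu_seqlens.astype(np.int32)

# Toy batch: three prompts of lengths 3, 2 and 4, right-padded with pad id 0.
ids = np.array([[11, 12, 13, 0], [21, 22, 0, 0], [31, 32, 33, 34]])
mask = (ids != 0).astype(np.int64)
packed, cu_seqlens = remove_padding(ids, mask)
print(packed)      # [11 12 13 21 22 31 32 33 34]
print(cu_seqlens)  # [0 3 5 9]
```

Each sample then attends only within its `[cu_seqlens[i], cu_seqlens[i+1])` span, so no compute is spent on pad tokens and, as the commit message notes, inputs no longer need to be padded to a multiple of the tensor-parallel degree. The next hunk in the commit adds a new file, the reward HTTP server referenced by `python reward_server.py` above.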
@@ -0,0 +1,112 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Launch Reward HTTP Server."""
+
+import argparse
+import json
+import logging
+import threading
+import traceback
+from typing import List
+
+import uvicorn
+from fastapi import FastAPI
+from pydantic import BaseModel
+
+
+class Request(BaseModel):
+    """The request for RM server."""
+
+    src: List[str]
+    tgt: List[str]
+    response: List[str]
+
+
+class Response(BaseModel):
+    """The response for RM server."""
+
+    error_code: int = 0
+    error_msg: str = "Success"
+    score: List[float] = None
+
+
+def compute_score(
+    solution_str: str, ground_truth: str, query=None, format_reward: int = 1, answer_reward: float = 1.0
+):
+    score = float(1.0)
+    print(
+        f"==============================================================={ground_truth}=========================================================================="
+    )
+    print(f"score {score}, solution_str\n", solution_str)
+    print(
+        "================================================================================================================================================="
+    )
+    return score
+
+
+def setup_args():
+    """Setup inference server arguments."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=8731)
+    parser.add_argument("--log_file", type=str, default="rm_server.log")
+    args = parser.parse_args()
+    return args
+
+
+def server(args):
+    """Launch RM server."""
+    app = FastAPI()
+    lock = threading.Lock()
+
+    logging.basicConfig(
+        level=logging.INFO,
+        filename=args.log_file,
+        filemode="w",
+        format="%(asctime)s - %(message)s",
+    )
+
+    @app.post("/")
+    async def _server(request: Request) -> Response:
+        lock.acquire()
+        logging.info(f"Request: {request}")
+        try:
+            all_result = []
+            if len(request.tgt) != len(request.response) or len(request.tgt) != len(request.src):
+                raise ValueError("The length of response, tgt, and src should be equal.")
+            for i in range(len(request.response)):
+                reward = compute_score(request.response[i], request.tgt[i], request.src[i])
+                all_result.append(reward)
+            output = {
+                "error_code": 0,
+                "error_msg": "Success",
+                "score": all_result,
+            }
+        except Exception as err:
+            logging.error(f"Server error: when process {request}\n{traceback.format_stack()}")
+            output = {
+                "error_code": 500,
+                "error_msg": f"{err}",
+                "score": [0] * len(request.tgt),
+            }
+        logging.info(f"Response: {json.dumps(output, indent=2, ensure_ascii=False)}")
+        lock.release()
+        return output
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
+
+
+if __name__ == "__main__":
+    args = setup_args()
+    server(args)
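The new file above exposes a single POST endpoint that takes parallel `src`/`tgt`/`response` lists and returns one score per sample; its address is what the GRPO YAML configures as `reward_server`. The bundled `compute_score` is a placeholder that always returns 1.0, so in practice it is replaced with a task-specific rule. A small client-side sketch, assuming the server is running locally on its default port 8731:

```python
import requests  # third-party HTTP client, assumed available

payload = {
    "src": ["<|im_start|>user\nWho is the knight?<|im_end|>"],  # prompts (illustrative only)
    "tgt": ["(1) Zoey is a knight"],                            # ground-truth labels
    "response": ["<answer> (1) Zoey is a knight </answer>"],    # rollouts to score
}
resp = requests.post("http://127.0.0.1:8731/", json=payload, timeout=30)
result = resp.json()
print(result["error_code"], result["score"])  # 0 [1.0] with the placeholder compute_score
```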

llm/alignment/ppo/run_ppo.py (-3)
@@ -100,8 +100,6 @@ def create_actor_models(
     actor_model_config.set_attn_func = True
     actor_model_config.max_position_embeddings = data_args.max_length
     actor_model_config.use_sparse_head_and_loss_fn = False
-    actor_model_config.fused_linear = model_args.fused_linear
-    actor_model_config.use_fused_rms_norm = training_args.use_fused_rms_norm
     actor_model_config.seq_length = data_args.max_length
     actor_model_config.max_sequence_length = data_args.max_length
     print(f"Loading Actor model with config:\n\t{actor_model_config}\n")
@@ -172,7 +170,6 @@ def create_reward_models(
     LlmMetaConfig.set_llm_config(reward_model_config, training_args)
     reward_model_config.max_position_embeddings = data_args.max_length
     reward_model_config.use_sparse_head_and_loss_fn = False
-    reward_model_config.fused_linear = model_args.fused_linear
     print(f"Loading Reward model with config:\n\t{reward_model_config}\n")
 
     config = copy.deepcopy(reward_model_config)

llm/config/llama/grpo_argument.yaml (+19 -21)
@@ -11,7 +11,7 @@ reward_server: "http://127.0.0.1:8731" # The address of the reward model server
 logging_dir: grpo-logs # Directory for logging
 logging_steps: 1 # Number of steps between logging
 output_dir: "qwen2.5-7b-kk-dataset-grpo/checkpoints" # Directory for output ckpts
-report_to: "wandb" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl"(default), "none"
+report_to: "visualdl" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl"(default), "none"
 wandb_http_proxy: "http://127.0.0.1:8962" # HTTP proxy for wandb
 run_name: "qwen2.5-7b-kk-dataset-grpo" # Name of the run
 
@@ -22,12 +22,13 @@ prompt_key: "src" # Key for the prompt in the dataset
 response_key: "tgt" # Key for the response in the dataset
 dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader
 balance_batch: true # Whether to balance batch size across dataset_world_size
+use_remove_padding: true # Whether to remove padding tokens in the input
 
 # distributed training args
 tensor_parallel_degree: 2 # Degree of tensor parallelism
 sequence_parallel: true # Whether to enable sequence parallelism
-sharding_parallel_degree: 1 # Degree of sharding parallelism
-sharding: "stage2" # Sharding strategy, e.g., "stage1" or "stage2"
+sharding_parallel_degree: -1 # Degree of sharding parallelism
+sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2"
 sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism
 pipeline_parallel_degree: 1 # Degree of pipeline parallelism
 virtual_pp_degree: 1 # Degree of virtual pipeline parallelism
@@ -39,24 +40,23 @@ min_dec_len: 32 # Minimum length of the response
 top_p: 1.0 # Top-p sampling parameter
 temperature: 0.7 # Temperature parameter for sampling
 repetition_penalty: 1.0 # Repetition penalty parameter
-# rollout_use_dynamic_insert: 1 # Whether to use dynamic insert for rollout
-# rollout_continue_batching_batch_size: 32 # Base batch size for dynamic insert
-quant_type: "" # Quantization type, e.g., "weight_only_int8"
+rollout_max_num_seqs: 32 # The maximum number of sequences that can be processed in a single inference
+rollout_quant_type: "" # Quantization type, e.g., "weight_only_int8"
 
 # training args
 do_train: true # Whether to perform training
 seed: 42 # Random seed for reproducibility
-global_batch_size: 2 # Global batch size for training
-mini_batch_size: 2 # Mini-batch size for training
+global_batch_size: 4 # Global batch size for training
+global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
+global_mini_batch_size: -1 # Mini-batch size for training
 rollout_n: 8 # Number of rollouts
 update_iters: 1 # Number of training iterations for rollout samples
-per_device_rollout_batch_size: 1 # Rollout batch size per device
 per_device_logprob_batch_size: 8 # Log probability batch size per device
 per_device_reward_batch_size: 8 # Reward batch size per device
 per_device_value_batch_size: 8 # Value batch size per device
 per_device_train_batch_size: 8 # Training batch size per device
 # gradient_accumulation_steps: 1 # Gradient accumulation steps (auto-calculated)
-num_train_epochs: 3 # Number of training epochs
+num_train_epochs: 6 # Number of training epochs
 max_length: 4608 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
 learning_rate: 5e-7 # Learning rate for training
 lr_scheduler_type: "constant" # Learning rate scheduler type
@@ -65,15 +65,15 @@ adam_beta1: 0.9 # AdamW optimizer beta1
 adam_beta2: 0.999 # AdamW optimizer beta2
 adam_epsilon: 1e-8 # AdamW optimizer epsilon
 max_grad_norm: 1.0 # Maximum gradient norm for clipping
-max_steps: 3600 # Maximum number of training steps
+max_steps: -1 # Maximum number of training steps
 save_steps: 300 # Number of steps between model saves
 save_strategy: "steps" # Strategy for saving models
 ignore_save_lr_and_optim: true # Whether to ignore saving learning rate and optimizer state (leave empty if not specified)
 disable_tqdm: true # Whether to disable tqdm progress bar
 
 # RL args
 kl_coeff: 0.0 # KL coefficient
-kl_loss_coeff: 0.0 # KL loss coefficient
+kl_loss_coeff: 0.001 # KL loss coefficient
 pg_loss_coeff: 1.0 # Policy gradient loss coefficient
 entropy_coeff: 0.0 # Entropy coefficient
 clip_range_ratio: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm)
@@ -84,12 +84,11 @@ enable_overlong_reward_buffer: false # Whether to enable overlong reward buffer
 overlong_reward_buffer: 256 # The length of the overlong reward buffer
 overlong_penalty_factor: 1.0 # The penalty factor for overlong reward buffer
 clip_range_value: 5.0 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value].
-normalize_reward: true # Whether to normalize reward
-normalize_advantage: true # Whether to normalize advantage
+normalize_reward: false # Whether to normalize reward
+normalize_advantage: false # Whether to normalize advantage
 dynamic_sampling: false # Whether to use dynamic sampling, which is introduced in the DAPO algorithm https://arxiv.org/abs/2503.14476
-per_device_sample_batch_size: 1 # Sample batch size per device for dynamic sampling
 max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling
-use_fp32_compute: false # Whether to use fp32 to compute xx_log_prob, rewards, advantages and loss
+use_fp32_compute: true # Whether to use fp32 to compute xx_log_prob, rewards, advantages and loss
 
 # eval args
 do_eval: true # Whether to perform evaluation
@@ -99,11 +98,10 @@ eval_steps: 20 # Number of steps between evaluations
 
 # device memory optimization args
 use_flash_attention: true # Whether to use fused attention operations
-use_fused_rms_norm: true # Whether to use fused RMS norm operations
-use_fused_rope: true # Whether to use fused rope operations
-use_fused_head_and_loss_fn: false # Whether to use fused head and loss function
-use_fused_linear: false # Whether to use fused linear operations, which needs to install fused_ln in slm/model_zoo/gpt-3/external_ops
-fused_linear: false # Whether to use fused_gemm_epilogue
+use_fused_rms_norm: true # Whether to use fused RMS norm operations, which needs to install fused_ln in slm/model_zoo/gpt-3/external_ops
+use_fused_rope: false # Whether to use fused rope operations
+use_fused_head_and_loss_fn: true # Whether to use fused head and loss function
+use_fused_linear: true # Whether to use fused linear operations
 recompute: true # Whether to enable gradient checkpointing for memory optimization
 recompute_use_reentrant: true # Whether to use reentrant recompute
 recompute_granularity: "full" # Granularity of recompute

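One constraint worth checking before a run: both the README and the `max_length` comment above require `max_prompt_len + max_dec_len` to stay below `max_length`. A hedged sketch for validating that, assuming PyYAML is installed and that the two `*_len` keys are present in the file as the README's parameter list describes:

```python
import yaml  # PyYAML, assumed available

CONFIG_PATH = "llm/config/llama/grpo_argument.yaml"

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# README rule: max_dec_len + max_prompt_len should be smaller than max_length.
budget = cfg.get("max_prompt_len", 0) + cfg.get("max_dec_len", 0)
assert budget < cfg["max_length"], (
    f"max_prompt_len + max_dec_len = {budget} must be smaller than max_length = {cfg['max_length']}"
)
print("sequence-length budget OK:", budget, "<", cfg["max_length"])
```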