GRPO(Group Relative Policy Optimization,组相对策略优化)是 PPO(Proximal Policy Optimization,近端策略优化)算法的一种变体。与 PPO 不同,GRPO 省略了价值函数估计器。在 GRPO 中,对于每个状态
以下是详细的使用文档和示例:
- 训练环境:
- 参考 Paddle 官网安装 PaddlePaddle-GPU, 要求 PaddlePaddle>=3.0
- clone 并安装 PaddleNLP
git clone https://github.com/PaddlePaddle/PaddleNLP.git- 安装 paddlenlp_ops 推理算子,参考 PaddleNLP/csrc 进行安装(必需)
cd your_PaddleNLP_path/csrc
python setup_cuda.py install- 安装 fused_ln 和 fast_ln 训练算子,参考 PaddleNLP/slm/model_zoo/gpt-3/external_ops (必须)
cd your_PaddleNLP_path/slm/model_zoo/gpt-3/external_ops
python setup.py install| 模型系列 | 模型名称 |
|---|---|
| Qwen1.5 | Qwen/Qwen1.5-0.5B, Qwen/Qwen1.5-0.5B-Chat, Qwen/Qwen1.5-1.8B, Qwen/Qwen1.5-1.8B-Chat, Qwen/Qwen1.5-4B, Qwen/Qwen1.5-4B-Chat, Qwen/Qwen1.5-7B, Qwen/Qwen1.5-7B-Chat, Qwen/Qwen1.5-14B, Qwen/Qwen1.5-14B-Chat, Qwen/Qwen1.5-32B, Qwen/Qwen1.5-32B-Chat |
| Qwen2 | Qwen/Qwen2-0.5B, Qwen/Qwen2-0.5B-Instruct, Qwen/Qwen2-1.5B, Qwen/Qwen2-1.5B-Instruct, Qwen/Qwen2-7B, Qwen/Qwen2-7B-Instruct, Qwen/Qwen2-72B, Qwen/Qwen2-72B-Instruct, Qwen/Qwen2-57B-A14B, Qwen/Qwen2-57B-A14B-Instruct |
| Qwen2-Math | Qwen/Qwen2-Math-1.5B, Qwen/Qwen2-Math-1.5B-Instruct, Qwen/Qwen2-Math-7B, Qwen/Qwen2-Math-7B-Instruct |
| Qwen2.5 | Qwen/Qwen2.5-0.5B, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-3B, Qwen/Qwen2.5-3B-Instruct, Qwen/Qwen2.5-7B, Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-14B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B, Qwen/Qwen2.5-32B-Instruct, |
| Qwen2.5-Math | Qwen/Qwen2.5-Math-1.5B, Qwen/Qwen2.5-Math-1.5B-Instruct, Qwen/Qwen2.5-Math-7B, Qwen/Qwen2.5-Math-7B-Instruct |
| Qwen2.5-Coder | Qwen/Qwen2.5-Coder-1.5B, Qwen/Qwen2.5-Coder-1.5B-Instruct, Qwen/Qwen2.5-Coder-7B, Qwen/Qwen2.5-Coder-7B-Instruct |
- src (list(str)): 经过 chat_template 处理后的 prompt 输入;或者根据需要自己拼接构造 prompt;
- tgt (list(str)): 标签内容;
{
"src": ["<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within <answer> </answer> tags. i.e., <answer> (1) Zoey is a knight\n(2) ... </answer>.\n<|im_end|>\n<|im_start|>user\nA very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 3 inhabitants: Michael, Zoey, and Ethan. Michael was heard saying, \"Ethan is a knight if and only if Michael is a knight\". \"Zoey is a knight or Ethan is a knight,\" Zoey mentioned. Ethan asserted: \"Michael is a knave if and only if Zoey is a knave\". So who is a knight and who is a knave?\n<|im_end|>\n<|im_start|>assistant\n<think>"],
"tgt": ["(1) Michael is a knight\n(2) Zoey is a knight\n(3) Ethan is a knight"]
}我们提供了一版使用 Qwen/Qwen2.5-7B-Instruct-1M 的chat template预处理后的KK 数据集。
wget https://paddlenlp.bj.bcebos.com/datasets/examples/ppo-kk.tgz && tar zxf ppo-kk.tgz
我们采用的配置文件放置在llm/config/qwen/grpo_argument.yaml中,同时我们提供了详细参数释义如下:
rl_algorithm: 使用的强化学习算法,支持grpo、reinforce_plus_plusactor_model_name_or_path: actor-model 和 reference-model 模型本地的模型路径reward_model_name_or_path: reward 模型的名称或本地路径use_rm_server: 是否使用 reward model server,设置为False时需要提供reward_model_name_or_pathreward_server: reward model server 的 URL 地址, 比如http://127.0.0.1:8731logging_dir: 日志保存的文件夹logging_steps: 训练日志打印的间隔步数output_dir: 模型参数保存目录report_to: 训练可视化工具,支持 "all"、"wandb"、"tensorboard"、"visualdl"、"none"wandb_http_proxy: 连接 wandb 使用的 HTTP 代理run_name: 实验名称train_datasets: 训练集路径eval_datasets: 验证集路径prompt_key: 数据集中 query 对应的字段名response_key: 数据集中 response 对应的字段名dataloader_drop_last: dataloader 是否丢弃最后不完整的 batchbalance_batch: 该参数用于指定是否在数据并行场景下,对批次内的 token 数量进行均衡分配。若设置为 True,系统将尝试在不同并行设备间平衡 token 的分布;若设置为 False(默认值),则不进行此类均衡操作。use_remove_padding: 此参数决定是否在训练过程中去除输入数据中的 padding 部分。启用该选项(设置为 True)可有效提高训练过程中有效 token 的占比,从而提升训练效率;若设置为 False(默认值),则保留输入数据中的 padding。tensor_parallel_degree: 张量并行度sequence_parallel: 是否启用序列并行sharding_parallel_degree: sharding 并行度sharding: 分片策略,支持 "stage1" 或 "stage2"sharding_parallel_config: sharding 并行配置pipeline_parallel_degree: 流水线并行度virtual_pp_degree: 虚拟流水线并行度max_prompt_len: 生成样本时的最大生成长度, max_length 调大会增加生成时间,并且增加显存占用。注意: max_dec_len + max_prompt_len 应当小于 max_seq_len。max_dec_len: 最大生成长度min_dec_len: 最小生成长度top_p: 生成解码超参数temperature: 生成解码超参数repetition_penalty: 生成解码超参数rollout_max_num_seqs: 单次推理可以处理的最大序列数rollout_quant_type: 量化类型,例如 "weight_only_int8"seed: 随机种子global_batch_size: 一次(一个 step)推理(rollout)采样的 prompt 数量global_mini_batch_size: actor model 更新一次参数训练的 prompt 数量rollout_n: 一个 prompt 采样的 response 数量update_iters: 同一批数据训练次数per_device_logprob_batch_size: 计算 log_probs 时,一个 batch 的样本数量per_device_reward_batch_size: critic model 计算 loss 与反向传播时,一个 batch 的的样本数量per_device_value_batch_size: critic model 前向计算 values 时,一个 batch 的的样本数量per_device_train_batch_size: actor model 计算 loss 与反向传播时,一个 batch 的样本数量num_train_epochs: 训练的 epoch 数max_length: 训练时的最大长度,应大于max_prompt_len和max_dec_len之和learning_rate: 学习率lr_scheduler_type: Actor 模型要使用的学习率调度策略。 (str, 可选, 默认为linear)weight_decay: AdamW 优化器的权重衰减adam_beta1: AdamW 优化器的 beta1adam_beta2: AdamW 优化器的 beta2adam_epsilon: AdamW 优化器的 epsilonmax_grad_norm: 梯度裁剪的最大值max_steps: 总的训练步数save_steps: 模型参数保存的间隔步数ignore_save_lr_and_optim: 是否忽略保存学习率和优化器状态kl_coeff: KL 惩罚系数kl_loss_coeff: KL Loss 系数pg_loss_coeff: 策略梯度损失系数entropy_coeff: entropy loss 系数clip_range_ratio: PPO-Clip 裁剪阈值clip_range_ratio_low: PPO-Clip 裁剪下限阈值clip_range_ratio_high: PPO-Clip 裁剪上限阈值clip_range_score: reward 的剪切范围,reward 会被限制在 [-clip_range_score, clip_range_score] 范围内clip_range_value: value 模型输出的剪切范围,value 会被限制在 [-clip_range_value, clip_range_value] 范围内normalize_reward: 是否使用 reward 标准化normalize_advantage: 是否使用 advantage 标准化use_fp32_compute: 是否使用 fp32 来计算 log_prob、reward、advantage 和 lossdo_eval: 是否进行评估per_device_eval_batch_size: 估 batch 大小evaluation_strategy: 评估策略,例如stepseval_steps: 模型评估的间隔步数use_flash_attention: 是否启用 FlashAttention-2,默认为 Falseuse_fused_rms_norm: 是否使用融合的 RMSNorm 算子,需安装 fused_lnrecompute: Actor 模型是否使用重计算策略,开启后可节省训练显存recompute_granularity: Actor 模型的重计算的粒度,可选项为core_attn和full.core_attn速度快但是显存占用,full速度慢但是显存占用低bf16: 使用 bfloat16 精度进行模型训练和推理。fp16_opt_level: float16 精度训练模式,O2表示纯 float16 训练amp_custom_black_list: 自定义 AMP 黑名单amp_custom_white_list: 自定义 AMP 白名单
cd your_PaddleNLP_path/llm/alignment/rl# 启动 reward server
python reward_server.pyexport PYTHONPATH=your_PaddleNLP_path/:$PYTHONPATH
export PYTHONPATH=your_PaddleNLP_path/llm:$PYTHONPATH
export FLAGS_set_to_1d=False
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_dataloader_use_file_descriptor=False
export HF_DATASETS_DOWNLOAD_TIMEOUT=1
export FLAGS_gemm_use_half_precision_compute_type=False
export FLAGS_force_cublaslt_no_reduced_precision_reduction=True
export FLAGS_mla_use_tensorcore=0
export FLAGS_cascade_attention_max_partition_size=2048
python -u -m paddle.distributed.launch --devices "0,1,2,3" run_rl.py ../../config/qwen/grpo_argument.yaml
# QWEN32B 2k prompt + 30k response 9台8x80G 显卡训练命令如下:
# python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_rl.py ../../config/qwen/grpo_32b_argument.yaml我们提供根据上述脚本可复现的wandb 日志。
cd your_PaddleNLP_path/llm/alignment/rl# 启动 reward server
python reward_server.pyexport PYTHONPATH=your_PaddleNLP_path/:$PYTHONPATH
export PYTHONPATH=your_PaddleNLP_path/llm:$PYTHONPATH
export FLAGS_set_to_1d=False
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_dataloader_use_file_descriptor=False
export HF_DATASETS_DOWNLOAD_TIMEOUT=1
export FLAGS_gemm_use_half_precision_compute_type=False
export FLAGS_force_cublaslt_no_reduced_precision_reduction=True
export FLAGS_mla_use_tensorcore=0
export FLAGS_cascade_attention_max_partition_size=2048
python -u -m paddle.distributed.launch --devices "0,1,2,3" run_rl.py ../../config/qwen/reinforce_plus_plus_argument.yaml我们提供根据上述脚本可复现的wandb 日志。
在grpo_argument.yaml和reinforce_plus_plus_argument.yaml中设置的输出目录为"logging_dir": "vdl_log", 可以通过以下命令查看训练过程
visualdl --logdir vdl_log --host 0.0.0.0也支持 wandb 等多种监控,可设置"logging_dir": "wandb",需要提前安装好 wandb 依赖并登录。