CoLA任务RL训练代码库
[English|中文]
本仓库主要聚焦于利用Qwen3系列模型,通过强化学习(Reinforcement Learning, RL)技术,完成GLUE基准中的CoLA(Corpus of Linguistic Acceptability)子任务的句子可接受性分类。代码实现了数据预处理、模型训练和评估全流程,方便快速上手与复现相关研究。
[25/09/08]整理主要结果,上传wandb模型训练记录。
[25/07/23]修复数据准备错误,新增SFT数据准备代码和训练脚本(基于LLaMA-Factory框架),新增REMAX、DAPO训练脚本,新增Text Classification代码,新增DeepSeek R1 0528蒸馏COLA数据集。
[25/06/22]完成数据处理流程、模型GRPO训练脚本(基于verl框架)和文档编写
- 对比指标Matthews相关系数(MCC)
- 提示词:
Decide whether the following sentence is grammatically acceptable or not. If it is grammatically correct, answer "acceptable". If not, answer "unacceptable". Only output "acceptable" or "unacceptable", and do not output any other information.
Sentence: {sentence}
Your answer:
| Model | Fine-tuning method | 验证集 | 测试集(kaggle) |
|---|---|---|---|
| Qwen3-0.6B | - | 0.223 | 待测试 |
| DeepSeek V3 0324 | - | 0.726 | 待测试 |
| DeepSeek R1 0120 | - | 0.636 | 待测试 |
| DeepSeek R1 0528 | - | 0.658 | 待测试 |
| Qwen3-1.7B-Remax | Remax (RL) | 0.658 | 待测试 |
| Qwen3-1.7B-GRPO | GRPO (RL) | 0.669 | 待测试 |
| Qwen3-1.7B-SFT-E1-GRPO | SFT + GRPO (RL) | 0.702 | 待测试 |
| Bert-base | CLS | 0.548 | 待测试 |
| Qwen3-0.6B-CLS | CLS | 0.610 | 待测试 |
| Qwen3-1.7B-SFT | SFT | 0.657 | 待测试 |
| Qwen3-0.6B-SFT | SFT | 0.598 | 待测试 |
Note
SFT-E1-GRPO后缀表示模型先进行1个Epoch的SFT后再进行GRPO。
CLS后缀表示模型输出头变为分类头。
- 各模型训练记录wandb log
Tip
参阅文档。
- 从魔搭社区或Huggingface下载
Qwen3系列模型到model文件夹下。
- 修改脚本
run_grpo_qwen3_0.6b.sh,修改wandb api key、工作目录和训练GPU编号。 - 启动训练:
bash run_grpo_qwen3_0.6b.sh- 对比不同RL算法对CoLA分类的效果。
- 对比不同参数量模型对CoLA分类的效果。
- 上传
wandb报告。
