Skip to content

rank-Yu/rwkv-state-tuning-guide

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 

Repository files navigation

RWKV7 State Tuning 教程

State Tuning 是什么?

RWKV 是纯 RNN,因此可以做 transformer 难以做到的事情。作为 RNN 有固定大小的 state,所以微调 RWKV 的初始 state,就相当于最彻底的 prompt tuning,甚至可以用于 alignment,因为迁移能力很强。

本文的 State tuning 方法来自 RWKV 社区微调项目 RWKV-PEFT 。

开始之前,请确保你拥有一个 Linux 工作区,以及支持 CUDA 的 NVIDIA 显卡。

整理训练数据

收集 jsonl 格式训练数据

要 state tuning 微调 RWKV 模型,需要使用收集适合训练 RWKV 的数据(jsonl 格式)。

我们以 liumindmind/NekoQA-10K 数据集为例,如图:

该数据集共有 10066 条样本,如图:

将该数据集的 NekoQA-10K.json 文件下载到本地:

写代码将 NekoQA-10K.json 文件转换成对话文本格式的 jsonl 文件:

import json

# 请在此修改文件名
input_filename = "NekoQA-10K.json"
output_filename = "nekoqa_10k_formatted.jsonl"

def convert_local_json_to_jsonl(input_file, output_file):
    try:
        # 读取本地 JSON 文件
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        count = 0
        with open(output_file, 'w', encoding='utf-8') as out_f:
            for item in data:
                # 获取数据
                instr = item.get("instruction", "").strip()
                out = item.get("output", "").strip()
                
                if not instr and not out:
                    continue
                
                # 拼接成对话文本格式
                formatted_text = f"User: {instr}\n\nAssistant: {out}"
                
                # 写入 JSONL
                json_line = json.dumps({"text": formatted_text}, ensure_ascii=False)
                out_f.write(json_line + "\n")
                count += 1
        
        print(f"转换完成!已处理 {count} 条数据,保存至: {output_file}")

    except FileNotFoundError:
        print(f"错误:找不到文件 {input_file},请确认文件名和路径是否正确。")
    except Exception as e:
        print(f"发生错误: {e}")

if __name__ == "__main__":
    convert_local_json_to_jsonl(input_filename, output_filename)

配置训练环境

请参考RWKV 微调环境配置板块,配置 Conda 等训练环境。

克隆仓库并安装依赖

在 Linux 或 WSL 中,使用 git 命令克隆 RWKV-PEFT 仓库:

git clone https://github.com/JL-er/RWKV-PEFT.git
# 如果 GitHub 无法链接,请使用以下国内仓库:
git clone https://gitee.com/rwkv-vibe/RWKV-PEFT.git

克隆完成后,使用 cd RWKV-PEFT 命令进入 RWKV-PEFT 目录。并运行以下命令,安装项目所需依赖:

pip install -r requirements.txt

下载 rwkv7-1.5b 模型

这里下载 rwkv7-g1b-1.5b-20251202-ctx8192.pth:

修改训练参数

使用任意文本编辑器(如 vscode)打开 RWKV-PEFT/scripts 目录的 state tuning.sh 文件,修改训练参数,进而控制微调的训练过程和训练效果:

load_model="/root/gpufree-data/RWKV-PEFT/rwkv7-g1b-1.5b-20251202-ctx8192.pth"
proj_dir="/root/gpufree-data/RWKV-PEFT"
data_file="/root/gpufree-data/data/nekoqa_10k_formatted.jsonl"

n_layer=24
n_embd=2048

micro_bsz=32 # batch size,根据显存情况调整
epoch_save=1
epoch_steps=10066 # 将步长设置为数据集大小
ctx_len=500 # 上下文长度,根据显存情况调整

devices=1 # 显卡数量,根据实际情况修改

python train.py --load_model $load_model \
--proj_dir $proj_dir --data_file $data_file \
--vocab_size 65536 \
--data_type jsonl \
--n_layer $n_layer --n_embd $n_embd \
--ctx_len $ctx_len --micro_bsz $micro_bsz \
--epoch_steps $epoch_steps --epoch_count 6 --epoch_save $epoch_save \
--lr_init 5e-5 --lr_final 5e-6 \
--accelerator gpu --precision bf16 \
--devices $devices --strategy deepspeed_stage_1 --grad_cp 1 \
--my_testing "x070" \
--peft state --op fla \
--wandb NekoQA-10K # 使用wandb记录训练过程

开始训练

在 RWKV-PEFT 目录,运行 sh scripts/state tuning.sh 命令,开启 state tuning 。

正常开始训练后,应当是如下画面:

训练完毕后,应当可以在输出文件夹中找到训练好的 state 文件(.pth 格式)和训练日志(.txt 文件)。

我们可以在 wandb 看到损失函数的变化情况:

使用 state 文件

获得 state 文件后,我们用 RWKV Runner 工具给模型挂载 state 文件(RWKV教程详见这里)。

state 文件需要配合基底 RWKV 模型,才能发挥其效果。在 RWKV Runner 中,你可以按照以下步骤使用 state 文件:

  • 启动一个 RWKV 模型
  • 在配置页面选择与模型尺寸对应的 state
  • 点击 保存配置 按钮

点击保存后即可实时更新 state ,无需重新启动 RWKV 模型。

我们的数据基于 猫娘对话数据集,训练出来的 State 文件效果如下:

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors