Meaning of the RL training hyperparameters #1

@lsm2842035890

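I don't really understand these parameters. I'd like to know what training logic results from these settings, so that I can speed up RL training on multiple GPUs: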
num_generations: 8
num_iterations: 2
per_device_eval_batch_size: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 2
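A minimal sketch of how one micro-step's work could be laid out across GPUs with this config. It assumes that each prompt is repeated `num_generations` times (one slot per answer to roll out) and that the slots are dealt to GPUs in order, `per_device_train_batch_size` at a time. `shard_micro_step` and the prompt ids are hypothetical names for illustration, not part of any library.

```python
def shard_micro_step(prompt_ids, num_gpus, per_device_train_batch_size, num_generations):
    """Hypothetical helper: for one micro-step, return the prompt id behind each
    answer that each GPU rolls out (one inner list per GPU)."""
    completions_per_step = num_gpus * per_device_train_batch_size
    if completions_per_step % num_generations != 0:
        raise ValueError(
            f"num_gpus * per_device_train_batch_size = {completions_per_step} "
            f"must be divisible by num_generations = {num_generations}"
        )
    prompts_needed = completions_per_step // num_generations
    if len(prompt_ids) < prompts_needed:
        raise ValueError(f"need at least {prompts_needed} prompts")
    # Assumption: each prompt is repeated num_generations times, one slot per answer.
    slots = [p for p in prompt_ids[:prompts_needed] for _ in range(num_generations)]
    # Assumption: consecutive slots go to the same GPU, per_device_train_batch_size each.
    return [
        slots[g * per_device_train_batch_size:(g + 1) * per_device_train_batch_size]
        for g in range(num_gpus)
    ]


# 4 GPUs, per_device_train_batch_size=2, num_generations=8.
print(shard_micro_step(["p0", "p1", "p2"], num_gpus=4,
                       per_device_train_batch_size=2, num_generations=8))
# [['p0', 'p0'], ['p0', 'p0'], ['p0', 'p0'], ['p0', 'p0']]
```

Under this assumption every GPU receives the same prompt twice, and the 8 answers for that single prompt are spread across the 4 GPUs.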

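With this config and the code, I observed that in every step each GPU loads the same 2 samples. With four GPUs, that means one step rolls out 8 answers, which give 8 reward values for computing the advantages; the gradients are applied once every 2 steps; and this whole process is iterated twice. (I'm not sure whether this is right.)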
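The "8 rewards to compute the advantages" part can be sketched as GRPO's group-relative normalization: each answer's reward is compared with the mean and standard deviation of the rewards of all answers to the same prompt. This is a minimal sketch; the reward values are made up, and the use of the sample standard deviation plus a small `eps` is an assumption (implementations differ in these details). `group_advantages` is a hypothetical name.

```python
import statistics


def group_advantages(rewards, num_generations, eps=1e-4):
    """GRPO-style advantages: normalize each reward within its group of
    num_generations answers to the same prompt."""
    if len(rewards) % num_generations != 0:
        raise ValueError("rewards must form whole groups of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mean = statistics.fmean(group)
        std = statistics.stdev(group)  # sample std (assumption); eps avoids division by zero
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# One prompt, 8 rolled-out answers, hypothetical 0/1 rewards.
rewards = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]
print([round(a, 3) for a in group_advantages(rewards, num_generations=8)])
```

Answers that beat their group's average get a positive advantage and the rest a negative one; if all 8 answers receive the same reward, every advantage is 0 and that prompt contributes no learning signal.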

After switching to 8 GPUs, for one step to still roll out 8 answers, per_device_train_batch_size would need to drop to 1, but then I run out of GPU memory (OOM).
num_generations: 8
num_iterations: 2
per_device_eval_batch_size: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 2
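Under the same assumption as above (one micro-step rolls out `num_gpus * per_device_train_batch_size` answers, grouped `num_generations` per prompt), the 4-GPU and 8-GPU setups compare as follows. `micro_step_size` is a hypothetical helper.

```python
def micro_step_size(num_gpus, per_device_train_batch_size, num_generations):
    """Return (answers per micro-step, distinct prompts per micro-step)."""
    answers = num_gpus * per_device_train_batch_size
    if answers % num_generations != 0:
        raise ValueError("answers per micro-step must be a multiple of num_generations")
    return answers, answers // num_generations


for gpus, per_device in [(4, 2), (8, 2), (8, 1)]:
    answers, prompts = micro_step_size(gpus, per_device, num_generations=8)
    print(f"{gpus} GPUs x {per_device} per GPU -> {answers} answers = {prompts} prompt(s) x 8 answers")
# 4 GPUs x 2 per GPU -> 8 answers = 1 prompt(s) x 8 answers
# 8 GPUs x 2 per GPU -> 16 answers = 2 prompt(s) x 8 answers
# 8 GPUs x 1 per GPU -> 8 answers = 1 prompt(s) x 8 answers
```

Read this way, keeping `per_device_train_batch_size: 2` on 8 GPUs does not change the group size: each prompt still gets 8 answers and 8 rewards, a micro-step simply covers 2 prompts instead of 1, and each GPU handles the same 2 answers per micro-step as in the 4-GPU run. Under this model, dropping to 1 is not required to keep groups of 8.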

I suspect I've misunderstood what these parameters actually do during training, so I'd be grateful for your guidance~
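For `num_iterations` and `gradient_accumulation_steps`, here is a minimal sketch of one plausible schedule. It assumes that rollouts are generated for `gradient_accumulation_steps` micro-batches, that gradients are accumulated over those micro-batches before each optimizer update, and that the same rollouts are then replayed for `num_iterations` passes in total. The exact ordering in a given trainer version is an assumption, and `rollout_schedule` is a hypothetical name.

```python
def rollout_schedule(num_rounds, gradient_accumulation_steps, num_iterations):
    """Return (round, iteration, micro_batch, action) events.

    micro_batch is None for the optimizer update that closes each pass.
    """
    events = []
    for rnd in range(num_rounds):
        for it in range(num_iterations):
            for mb in range(gradient_accumulation_steps):
                # Fresh rollouts only on the first pass; later passes reuse them.
                action = "generate + backward" if it == 0 else "reuse + backward"
                events.append((rnd, it, mb, action))
            # Gradients accumulated over all micro-batches are applied once.
            events.append((rnd, it, None, "optimizer step"))
    return events


for rnd, it, mb, action in rollout_schedule(1, gradient_accumulation_steps=2, num_iterations=2):
    where = f"micro-batch {mb}" if mb is not None else "end of pass"
    print(f"round {rnd} | iteration {it} | {where}: {action}")
```

Under this model, with 4 GPUs and `per_device_train_batch_size: 2`, one round of rollouts covers 2 micro-batches (2 prompts x 8 answers) and yields 2 optimizer updates; "iterated twice" means reusing the same rollouts, not generating new ones.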