GPU memory requirements of train_ppo #4

@skepsun

Description

I tried to perform SFT and PPO on a 3B LLaMA model on 4xA100 (40 GB). In the PPO stage (using bloom-1b7-rm as the reward model), I always get OOM errors, even with very small batch sizes (1 for the actor, 4 for the reward model). I'm curious: how much GPU memory should I prepare for PPO training of a 3B model?
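For context, here is a rough back-of-envelope sketch of why this setup can run out of memory. PPO keeps four models resident: the trained actor and critic, plus a frozen reference model and the frozen reward model. The byte-per-parameter figures below (≈18 bytes/param for trained models under mixed-precision Adam, ≈2 bytes/param for frozen fp16 models), the assumption that the critic is initialized from the 1.7B reward model, and the model sizes are all my own assumptions, not measurements from this repo:

```python
# Back-of-envelope estimate of PPO model-state memory only
# (no activations, no fragmentation). Assumed byte counts:
#   trained models (mixed-precision Adam): fp16 weights + fp16 grads
#     + fp32 master weights + fp32 momentum + fp32 variance ~= 18 B/param
#   frozen models (fp16 inference only): ~= 2 B/param

GIB = 1024 ** 3

models = {
    # name:                 (params, bytes_per_param, trained?)
    "actor (llama-3b)":     (3.0e9, 18, True),   # policy being optimized
    "reference (llama-3b)": (3.0e9,  2, False),  # frozen copy for the KL penalty
    "reward (bloom-1b7)":   (1.7e9,  2, False),  # frozen reward model
    "critic (bloom-1b7)":   (1.7e9, 18, True),   # value model, assumed init from RM
}

total = 0.0
for name, (n_params, bytes_per_param, trained) in models.items():
    gib = n_params * bytes_per_param / GIB
    total += gib
    print(f"{name:24s} {'trained' if trained else 'frozen':7s} ~{gib:5.1f} GiB")

print(f"{'total model states':32s} ~{total:5.1f} GiB")
# Well over 80 GiB of model states alone, before any activations or
# generation buffers -- more than half of the 4 x 40 GiB budget, so OOM at
# tiny batch sizes is plausible without ZeRO-3 sharding, CPU offload, or
# activation checkpointing.
```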
