I tried to run SFT and PPO on a 3B LLaMA model on 4x A100 (40 GB). In the PPO stage (using bloom-1b7-rm as the reward model), I always get OOM errors, even with very small batch sizes (1 for the actor, 4 for the reward model). How much GPU memory should I prepare for PPO training on a 3B model?
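For context on why I suspect memory is tight: here is my rough back-of-envelope estimate of the footprint, assuming mixed-precision Adam (fp16 weights + fp16 grads + fp32 master weights + fp32 Adam momentum + fp32 variance, roughly 16 bytes per parameter for trained models and 2 bytes for frozen fp16 models), and assuming PPO keeps an actor, a frozen reference model, the reward model, and a critic resident. The helper names and the 1.7B critic size are my own assumptions, not from any specific framework:

```python
def training_mem_gb(n_params_billion: float, bytes_per_param: int = 16) -> float:
    """Trained model under mixed-precision Adam:
    fp16 weights (2) + fp16 grads (2) + fp32 master weights (4)
    + fp32 momentum (4) + fp32 variance (4) = ~16 bytes/param."""
    return n_params_billion * bytes_per_param

def inference_mem_gb(n_params_billion: float, bytes_per_param: int = 2) -> float:
    """Frozen fp16 model: weights only, ~2 bytes/param."""
    return n_params_billion * bytes_per_param

actor  = training_mem_gb(3.0)    # 3B actor being trained
ref    = inference_mem_gb(3.0)   # frozen reference policy for the KL penalty
reward = inference_mem_gb(1.7)   # frozen bloom-1b7-rm
critic = training_mem_gb(1.7)    # assumed 1.7B critic, also trained

total = actor + ref + reward + critic
print(f"actor={actor} GB, ref={ref} GB, reward={reward} GB, critic={critic} GB")
print(f"total ≈ {total:.1f} GB (before activations and KV cache)")
```

By this estimate the model states alone approach the 160 GB of a 4x A100-40G node before counting activations, generation buffers, and fragmentation, which would explain the OOM even at batch size 1 unless optimizer/parameter sharding (e.g. ZeRO) or offloading is used.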