
single gpu 6.7b lora CUDA OOM with A6000 48G #330

Open
@HyeongminMoon

I am trying to run the DeepSpeed-Chat example on a single GPU (NVIDIA A6000, 48 GB).

All three steps run fine with the 1.3B example.
With single_gpu/run_6.7b_lora.sh, however, steps 1 and 2 complete but step 3 fails with a CUDA out-of-memory (OOM) error.

Even after scaling the configuration down as far as I could, I still get OOM.
Here is my step-3 run_6.7b_lora.sh:

ACTOR_ZERO_STAGE="--actor_zero_stage 0"
CRITIC_ZERO_STAGE="--critic_zero_stage 0"
ACTOR_MODEL_PATH=../step1_supervised_finetuning/output
CRITIC_MODEL_PATH=../step2_reward_model_finetuning/output

OUTPUT="./output"

Num_Padding_at_Beginning=1 # this is model related

Actor_Lr=5e-4
Critic_Lr=5e-6

mkdir -p $OUTPUT

deepspeed --num_gpus 1 main.py \
   --data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets openai/webgpt_comparisons stanfordnlp/SHP \
   --data_split 2,4,4 \
   --actor_model_name_or_path $ACTOR_MODEL_PATH \
   --critic_model_name_or_path $CRITIC_MODEL_PATH \
   --num_padding_at_beginning 1 \
   --per_device_train_batch_size 1 \
   --per_device_mini_train_batch_size 1 \
   --generation_batch_numbers 1 \
   --ppo_epochs 1 \
   --max_answer_seq_len 128 \
   --max_prompt_seq_len 128 \
   --ppo_epochs 1 \
   --actor_learning_rate ${Actor_Lr} \
   --critic_learning_rate ${Critic_Lr} \
   --actor_weight_decay 0.1 \
   --critic_weight_decay 0.1 \
   --num_train_epochs 1 \
   --lr_scheduler_type cosine \
   --gradient_accumulation_steps 8 \
   --num_warmup_steps 100 \
   --deepspeed --seed 1234 \
   ${ACTOR_ZERO_STAGE} \
   ${CRITIC_ZERO_STAGE} ${OFFLOAD}\
   --actor_lora_dim 128 \
   --actor_gradient_checkpointing \
   --critic_gradient_checkpointing \
   --enable_hybrid_engine \
   --output_dir $OUTPUT \
    &> $OUTPUT/training.log
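A rough per-parameter tally shows how tight 48 GB is for step 3. This is a back-of-envelope sketch, not DeepSpeed's own accounting. It assumes fp16 weights (2 bytes/param) for every model held on the GPU, the common mixed-precision Adam cost of about 16 bytes per trainable parameter (fp16 weight + fp16 grad + fp32 master copy + two fp32 Adam moments), and a reference model that is a plain fp16 copy of the actor. It ignores activations, the KV cache, the hybrid-engine workspace, and the critic/reward models, whose size depends on the step-2 output. The helper names and the two trainable-parameter cases are hypothetical and illustrative, not the exact split DeepSpeed-Chat uses.

```python
GB = 1e9  # decimal gigabytes, matching the "GigaBytes" figures in the log


def frozen_bytes(num_params: float) -> float:
    """fp16 weights only (2 bytes/param), e.g. a frozen reference model."""
    return 2 * num_params


def trainable_bytes(num_params: float) -> float:
    """fp16 weight + fp16 grad + fp32 master + fp32 Adam m and v = 16 bytes/param."""
    return (2 + 2 + 4 + 4 + 4) * num_params


def actor_plus_reference_gb(total_params: float, trainable_params: float) -> float:
    """Lower bound for actor + reference only (no critic, reward model or activations)."""
    frozen_actor = total_params - trainable_params
    actor = frozen_bytes(frozen_actor) + trainable_bytes(trainable_params)
    reference = frozen_bytes(total_params)  # assumption: fp16 copy of the actor
    return (actor + reference) / GB


if __name__ == "__main__":
    n = 6.7e9
    # Illustrative case 1: every actor parameter trainable (full fine-tuning)
    print(f"all params trainable: {actor_plus_reference_gb(n, n):.1f} GB")
    # Illustrative case 2: only ~0.3B adapter parameters trainable
    print(f"0.3B trainable:       {actor_plus_reference_gb(n, 0.3e9):.1f} GB")
```

The gap between the two cases (about 121 GB versus 31 GB under these assumptions) comes almost entirely from the 16-bytes-per-parameter training state, which is why the number of parameters left with requires_grad=True matters so much on a single GPU.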

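The OOM happens at the gradient accumulation boundary: steps 0 to 6 succeed, and the run fails during the actor's backward pass on the 8th step, which matches --gradient_accumulation_steps 8.
Here is where it fails: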

...
******************[end] Initialized Reward Model [end] (duration: 2.67s)******************
***** Running training *****
Beginning of Epoch 1/1, Total Generation Batches 264292
------------------------------------------------------
Free memory : 7.042725 (GigaBytes)  
Total memory: 47.544312 (GigaBytes)  
Requested memory: 0.304688 (GigaBytes) 
Setting maximum total tokens (input + output) to 512 
WorkSpace: 0x7f8f62000000 
------------------------------------------------------
/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
epoch: 0|step: 0|ppo_ep: 1|act_loss: 0.2626953125|cri_loss: 0.13720703125|unsuper_loss: 0.0
average reward score: 2.8828125
-------------------------------------------------------------------------------------
|E2E latency=3.87s |Gather latency=0.00s (0.00%) |Generate time=3.40s (87.89%) |Training time=0.25s (6.47%) |Others=0.22 (5.65%)|CurSamplesPerSec=0.26 |AvgSamplesPerSec=0.26
epoch: 0|step: 1|ppo_ep: 1|act_loss: 0.72265625|cri_loss: 0.4296875|unsuper_loss: 0.0
average reward score: 3.65625
-------------------------------------------------------------------------------------
|E2E latency=3.31s |Gather latency=0.00s (0.00%) |Generate time=2.87s (86.86%) |Training time=0.26s (7.71%) |Others=0.18 (5.43%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.28
epoch: 0|step: 2|ppo_ep: 1|act_loss: -0.394287109375|cri_loss: -0.1787109375|unsuper_loss: 0.0
average reward score: 4.98046875
-------------------------------------------------------------------------------------
|E2E latency=3.30s |Gather latency=0.00s (0.00%) |Generate time=2.87s (87.01%) |Training time=0.25s (7.64%) |Others=0.18 (5.35%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.29
epoch: 0|step: 3|ppo_ep: 1|act_loss: -0.053619384765625|cri_loss: -0.0135498046875|unsuper_loss: 0.0
average reward score: 5.6015625
-------------------------------------------------------------------------------------
|E2E latency=3.29s |Gather latency=0.00s (0.00%) |Generate time=2.86s (86.82%) |Training time=0.25s (7.71%) |Others=0.18 (5.47%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.29
epoch: 0|step: 4|ppo_ep: 1|act_loss: 0.49560546875|cri_loss: 0.264404296875|unsuper_loss: 0.0
average reward score: 1.3955078125
-------------------------------------------------------------------------------------
|E2E latency=3.29s |Gather latency=0.00s (0.00%) |Generate time=2.86s (86.89%) |Training time=0.25s (7.71%) |Others=0.18 (5.40%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.29
epoch: 0|step: 5|ppo_ep: 1|act_loss: -0.26171875|cri_loss: -0.1119384765625|unsuper_loss: 0.0
average reward score: 4.09765625
-------------------------------------------------------------------------------------
|E2E latency=3.29s |Gather latency=0.00s (0.00%) |Generate time=2.86s (86.94%) |Training time=0.25s (7.67%) |Others=0.18 (5.38%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.29
epoch: 0|step: 6|ppo_ep: 1|act_loss: -0.13427734375|cri_loss: -0.05322265625|unsuper_loss: 0.0
average reward score: 3.005859375
-------------------------------------------------------------------------------------
|E2E latency=3.29s |Gather latency=0.00s (0.00%) |Generate time=2.86s (86.86%) |Training time=0.26s (7.75%) |Others=0.18 (5.39%)|CurSamplesPerSec=0.30 |AvgSamplesPerSec=0.30
Traceback (most recent call last):
  File "/home/workspaces/mohomin/git_temp/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 516, in <module>
    main()
  File "/home/workspaces/mohomin/git_temp/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 439, in main
    actor_loss, critic_loss = trainer.train_rlhf(exp_data)
  File "/home/workspaces/mohomin/git_temp/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 172, in train_rlhf
    self.actor_model.backward(actor_loss)
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1842, in backward
    self.allreduce_gradients()
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1772, in allreduce_gradients
    self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 2273, in buffered_allreduce_fallback
    non_expert_grads, expert_grads = self._get_gradients_for_reduction()
  File "/home/ados/anaconda3/envs/DeepSpeed/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 2229, in _get_gradients_for_reduction
    param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 47.54 GiB total capacity; 45.81 GiB already allocated; 85.75 MiB free; 45.95 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[2023-04-17 16:44:36,269] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 5403
[2023-04-17 16:44:36,270] [ERROR] [launch.py:434:sigkill_handler] ['/home/ados/anaconda3/envs/DeepSpeed/bin/python', '-u', 'main.py', '--local_rank=0', '--data_path', 'Dahoas/rm-static', 'Dahoas/full-hh-rlhf', 'Dahoas/synthetic-instruct-gptj-pairwise', 'yitingxie/rlhf-reward-datasets', 'openai/webgpt_comparisons', 'stanfordnlp/SHP', '--data_split', '2,4,4', '--actor_model_name_or_path', '../step1_supervised_finetuning/output', '--critic_model_name_or_path', '../step2_reward_model_finetuning/output', '--num_padding_at_beginning', '1', '--per_device_train_batch_size', '1', '--per_device_mini_train_batch_size', '1', '--generation_batch_numbers', '1', '--ppo_epochs', '1', '--max_answer_seq_len', '128', '--max_prompt_seq_len', '128', '--ppo_epochs', '1', '--actor_learning_rate', '5e-4', '--critic_learning_rate', '5e-6', '--actor_weight_decay', '0.1', '--critic_weight_decay', '0.1', '--num_train_epochs', '1', '--lr_scheduler_type', 'cosine', '--gradient_accumulation_steps', '8', '--num_warmup_steps', '100', '--deepspeed', '--seed', '1234', '--actor_zero_stage', '0', '--critic_zero_stage', '0', '--actor_lora_dim', '128', '--actor_gradient_checkpointing', '--critic_gradient_checkpointing', '--enable_hybrid_engine', '--output_dir', './output'] exits with return code = 1
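The failure point lines up with the accumulation setting. Assuming DeepSpeed treats 0-based micro-step i as an accumulation boundary when (i + 1) % gradient_accumulation_steps == 0, and assuming each logged step performs exactly one actor backward (plausible here with per-device batch sizes, generation_batch_numbers and ppo_epochs all set to 1), the first boundary with --gradient_accumulation_steps 8 is step 7, right after the last logged step. The traceback shows allreduce_gradients calling buffered_allreduce_fallback and then _get_gradients_for_reduction, which allocates zero-filled .grad tensors at that point. The sketch below only reproduces the step arithmetic; both function names are hypothetical.

```python
def is_accumulation_boundary(micro_step: int, grad_accum_steps: int) -> bool:
    """Hypothetical helper: True when a 0-based micro-step closes an accumulation window."""
    return (micro_step + 1) % grad_accum_steps == 0


def first_boundary(grad_accum_steps: int) -> int:
    """First 0-based micro-step at which gradients would be reduced and applied."""
    step = 0
    while not is_accumulation_boundary(step, grad_accum_steps):
        step += 1
    return step


if __name__ == "__main__":
    gas = 8  # --gradient_accumulation_steps 8 from the script
    logged_steps = range(7)  # steps 0..6 appear in the log before the traceback
    print("boundary among logged steps:",
          any(is_accumulation_boundary(s, gas) for s in logged_steps))
    print("first boundary step:", first_boundary(gas))
```

Under these assumptions none of the logged steps reaches a boundary, so memory that is only touched at the boundary (the gradient reduction, and on the first optimizer step any lazily created optimizer state) first shows up on step 7.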

Environment

  • deepspeed-0.9.1+cc67f22f
  • CUDA 11.7
  • torch 2.0.0
  • python 3.9.16

I also tried adding --only_optimize_lora, but got the same error.
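For scale, the number of trainable parameters that --actor_lora_dim 128 adds can be estimated. A rank-r adapter on a d_in x d_out linear layer adds r * (d_in + d_out) parameters. This sketch assumes the actor is OPT-6.7B (32 decoder layers, hidden size 4096, FFN size 16384) and that LoRA wraps every linear layer in each decoder block (q/k/v/out projections plus fc1 and fc2); the set of layers DeepSpeed-Chat actually converts may differ, so treat the result as an estimate. The function names are hypothetical.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Parameters in one LoRA pair: a d_in x rank matrix plus a rank x d_out matrix."""
    return rank * (d_in + d_out)


def opt_decoder_lora_params(num_layers: int, hidden: int, ffn: int, rank: int) -> int:
    """Assumed layout: LoRA on q, k, v, out_proj (hidden->hidden), fc1 (hidden->ffn), fc2 (ffn->hidden)."""
    attn = 4 * lora_params(hidden, hidden, rank)
    mlp = lora_params(hidden, ffn, rank) + lora_params(ffn, hidden, rank)
    return num_layers * (attn + mlp)


if __name__ == "__main__":
    n = opt_decoder_lora_params(num_layers=32, hidden=4096, ffn=16384, rank=128)
    print(f"LoRA params: {n:,} (~{n / 1e9:.2f}B)")
    # Rough training-state cost at ~16 bytes per trainable parameter (mixed-precision Adam)
    print(f"training-state memory: ~{16 * n / 1e9:.1f} GB")
```

Under these assumptions rank 128 gives roughly 0.3B trainable parameters (about 4.8 GB of weights, grads and Adam state), so even with only the LoRA weights optimized, the adapter itself is not negligible on a 48 GB card.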
Is there any way to run the 6.7B LoRA example on a single 48 GB GPU?
Thanks for any help.

Labels: deespeed chat (DeepSpeed Chat), question (Further information is requested)
