
[BUG] DeepSpeed-Chat bloom training error: RuntimeError "still have inflight params" raised after 14 steps of step-3 training with the offload option turned on #591

Open
@DZ9

Description


Describe the bug
I'm training a bloom model in step 3 of DeepSpeed-Chat with the offload option turned on. After 14 training steps it raised the error shown in the logs below. I trained on a single node with 8 GPUs. I ran this training 3 times and it collapsed every time after about 1 hour of training. I switched the PyTorch version from 1.12 to 1.13 and still got the same error.

the full training command is:
['/home/user/bin/python', '-u', 'main.py', '--local_rank=7', '--data_path', 'openai/webgpt_comparisons', '--data_split', '2,4,4', '--actor_model_name_or_path', 'yuanzhoulvpi/chinese_bloom_7b_chat_v3', '--critic_model_name_or_path', '/mnt/workspace/workgroup/res/weights/chinese_bloom_560m/', '--num_padding_at_beginning', '1', '--per_device_train_batch_size', '8', '--per_device_mini_train_batch_size', '8', '--generation_batch_numbers', '1', '--ppo_epochs', '1', '--max_answer_seq_len', '256', '--max_prompt_seq_len', '256', '--ppo_epochs', '1', '--actor_learning_rate', '5e-4', '--critic_learning_rate', '5e-6', '--num_train_epochs', '1', '--lr_scheduler_type', 'cosine', '--gradient_accumulation_steps', '16', '--num_warmup_steps', '100', '--deepspeed', '--seed', '1234', '--actor_zero_stage', '3', '--critic_zero_stage', '3', '--actor_lora_dim', '128', '--actor_gradient_checkpointing', '--critic_gradient_checkpointing', '--disable_actor_dropout', '--offload']
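
My reading of what the error means, as a minimal illustrative sketch reconstructed from the DeepSpeed frames in the traceback below (parameter_offload.py -> _end_of_forward_hook -> partitioned_param_coordinator.py -> reset_step). This is not the actual DeepSpeed source: ZeRO-3 asynchronously prefetches partitioned parameters during forward, and the end-of-forward hook raises if some prefetch was started but never consumed.

```python
# Illustrative sketch only (assumed from the traceback frames below),
# not the real deepspeed.runtime.zero.partitioned_param_coordinator code.

class ParamCoordinatorSketch:
    def __init__(self):
        # parameters whose all-gather was launched but not yet waited on
        self.inflight_params = set()

    def fetch(self, param):
        # ZeRO-3 asynchronously prefetches partitioned params during forward
        self.inflight_params.add(param)

    def consume(self, param):
        # in a normal forward, every prefetched param is used and released
        self.inflight_params.discard(param)

    def reset_step(self):
        # called from the end-of-forward hook; if the forward/generation path
        # exited without consuming some prefetched params, this check fires
        if self.inflight_params:
            raise RuntimeError(
                f"still have inflight params {list(self.inflight_params)}")
```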

Log output

-------------------------------------------------------------------------------------
epoch: 0|step: 13|ppo_ep: 1|act_loss: 2.9296875|cri_loss: 11.4453125|unsuper_loss: 0.0
average reward score: -6.86328125
-------------------------------------------------------------------------------------
epoch: 0|step: 14|ppo_ep: 1|act_loss: 4.75390625|cri_loss: inf|unsuper_loss: 0.0
average reward score: -6.7890625
-------------------------------------------------------------------------------------
[2023-06-09 16:12:59,766] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1
[2023-06-09 16:13:00,631] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. Reducing hysteresis to 1
epoch: 0|step: 15|ppo_ep: 1|act_loss: 3.587890625|cri_loss: 15.0390625|unsuper_loss: 0.0
average reward score: -7.5546875
-------------------------------------------------------------------------------------
[The same RuntimeError was raised concurrently on all eight ranks (cuda:0 through cuda:7), with the tracebacks interleaved in the output; one clean copy follows.]
Traceback (most recent call last):
  File "/mnt/workspace/workgroup/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 520, in <module>
    main()
  File "/mnt/workspace/workgroup/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 428, in main
    out = trainer.generate_experience(batch_prompt['prompt'],
  File "/mnt/workspace/workgroup/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 105, in generate_experience
    output = self.actor_model(seq, attention_mask=attention_mask)
  File "/home/user/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/user/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1736, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/user/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1211, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/user/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/user/lib/python3.9/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 329, in _end_of_forward_hook
    self.get_param_coordinator(training=False).reset_step()
  File "/home/user/lib/python3.9/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 185, in reset_step
    raise RuntimeError(f"still have inflight params "
RuntimeError: still have inflight params [<bound method Init._convert_to_deepspeed_param.<locals>.ds_summary of Parameter containing:
tensor([[-0.0023, -0.0052, -0.0221,  ...,  0.0148,  0.0240, -0.0168],
        [ 0.0101, -0.0024, -0.0593,  ...,  0.0103, -0.0342, -0.0173],
        [-0.0229, -0.0079, -0.0126,  ...,  0.0009, -0.0197, -0.0182],
        ...,
        [-0.0042, -0.0012,  0.0121,  ...,  0.0123, -0.0023,  0.0014],
        [ 0.0121, -0.0019, -0.0132,  ...,  0.0031, -0.0141,  0.0038],
        [-0.0014,  0.0193,  0.0042,  ...,  0.0034, -0.0004, -0.0056]],
       device='cuda:5', dtype=torch.float16, requires_grad=True)>]
[2023-06-09 16:16:56,690] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110262
[2023-06-09 16:16:59,824] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110263
[2023-06-09 16:16:59,825] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110264
[2023-06-09 16:16:59,827] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110265
[2023-06-09 16:16:59,827] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110266
[2023-06-09 16:16:59,828] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110267
[2023-06-09 16:16:59,829] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110268
[2023-06-09 16:16:59,830] [INFO] [launch.py:314:sigkill_handler] Killing subprocess 110269
[2023-06-09 16:16:59,832] [ERROR] [launch.py:320:sigkill_handler] ['/home/user/bin/python', '-u', 'main.py', '--local_rank=7', '--data_path', 'openai/webgpt_comparisons', '--data_split', '2,4,4', '--actor_model_name_or_path', 'yuanzhoulvpi/chinese_bloom_7b_chat_v3', '--critic_model_name_or_path', '/mnt/workspace/workgroup/res/weights/chinese_bloom_560m/', '--num_padding_at_beginning', '1', '--per_device_train_batch_size', '8', '--per_device_mini_train_batch_size', '8', '--generation_batch_numbers', '1', '--ppo_epochs', '1', '--max_answer_seq_len', '256', '--max_prompt_seq_len', '256', '--ppo_epochs', '1', '--actor_learning_rate', '5e-4', '--critic_learning_rate', '5e-6', '--num_train_epochs', '1', '--lr_scheduler_type', 'cosine', '--gradient_accumulation_steps', '16', '--num_warmup_steps', '100', '--deepspeed', '--seed', '1234', '--actor_zero_stage', '3', '--critic_zero_stage', '3', '--actor_lora_dim', '128', '--actor_gradient_checkpointing', '--critic_gradient_checkpointing', '--disable_actor_dropout', '--offload'] exits with return code = 1
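
In case it helps narrow things down when reproducing, a small hypothetical debugging helper (not part of the original run) that dumps the ZeRO-3 state of each parameter to see which ones the coordinator still considers in flight. `ds_summary` appears in the error message above; `ds_status` is the availability flag DeepSpeed attaches to converted parameters. Attribute names may differ across DeepSpeed versions.

```python
# Hypothetical debugging aid, not part of the original report: dump the
# ZeRO-3 state of each parameter of the actor model after the failure.
# ds_summary() is referenced in the error message above; ds_status is the
# per-parameter availability flag DeepSpeed attaches to converted params.
def dump_zero3_param_state(model):
    for name, p in model.named_parameters():
        status = getattr(p, "ds_status", None)
        summary = p.ds_summary() if hasattr(p, "ds_summary") else None
        print(name, status, summary)
```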

ds_report output

Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/pai/lib/python3.9/site-packages/torch']
torch version .................... 1.13.0
deepspeed install path ........... ['/home/pai/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.9.3, unknown, unknown
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.3
deepspeed wheel compiled w. ...... torch 1.12, cuda 11.3

System info (please complete the following information):

  • OS: Ubuntu 20.04
  • GPUs: 8× V100 32GB
  • PyTorch: 1.13
  • Python: 3.9

How to reproduce
All the datasets and models I use are openly published on Hugging Face; the experiment can be reproduced by running the same command shown above.

Labels

deespeed chat (DeepSpeed Chat) · new-config (A modified config from the given example)
