
Missing key(s) in state_dict for bias in attention blocks #374

Open

Description

@EikeKohl

I am trying to run step 3 of the RLHF examples using a RewardModel checkpoint that I trained in step 2 of the examples. For each step, I used the provided .sh scripts and only adjusted the model and data paths. Unfortunately, I encountered the following exception:

*******************[end] Initialized Ref Model [end] (duration: 0.59s)********************
************************[start] Initializing Critic Model [start] ************************
Traceback (most recent call last):
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 516, in <module>
    main()
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 385, in main
    rlhf_engine = DeepSpeedRLHFEngine(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 57, in __init__
    self.critic = self._init_critic(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py", line 193, in _init_critic
    critic_model = create_critic_model(
  File "/home/ec2-user/SageMaker/deepspeedexamples-fork/applications/DeepSpeed-Chat/training/utils/model/model_utils.py", line 69, in create_critic_model
    critic_model.load_state_dict(
  File "/home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RewardModel:
        Missing key(s) in state_dict: "rwtransformer.h.0.attn.bias", "rwtransformer.h.0.attn.masked_bias", "rwtransformer.h.1.attn.bias", "rwtransformer.h.1.attn.masked_bias", "rwtransformer.h.2.attn.bias", "rwtransformer.h.2.attn.masked_bias", "rwtransformer.h.3.attn.bias", "rwtransformer.h.3.attn.masked_bias", "rwtransformer.h.4.attn.bias", "rwtransformer.h.4.attn.masked_bias", "rwtransformer.h.5.attn.bias", "rwtransformer.h.5.attn.masked_bias", "rwtransformer.h.6.attn.bias", "rwtransformer.h.6.attn.masked_bias", "rwtransformer.h.7.attn.bias", "rwtransformer.h.7.attn.masked_bias", "rwtransformer.h.8.attn.bias", "rwtransformer.h.8.attn.masked_bias", "rwtransformer.h.9.attn.bias", "rwtransformer.h.9.attn.masked_bias", "rwtransformer.h.10.attn.bias", "rwtransformer.h.10.attn.masked_bias", "rwtransformer.h.11.attn.bias", "rwtransformer.h.11.attn.masked_bias". 
[2023-04-20 11:58:58,807] [INFO] [launch.py:428:sigkill_handler] Killing subprocess 24065

These are the keys from model.named_parameters() before training step 2:

['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', 'h.0.ln_1.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias', 'h.0.attn.c_proj.weight', 'h.0.attn.c_proj.bias', 'h.0.ln_2.weight', 'h.0.ln_2.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_proj.weight', 'h.0.mlp.c_proj.bias', 'h.1.ln_1.weight', 'h.1.ln_1.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_proj.weight', 'h.1.attn.c_proj.bias', 'h.1.ln_2.weight', 'h.1.ln_2.bias', 'h.1.mlp.c_fc.weight', 'h.1.mlp.c_fc.bias', 'h.1.mlp.c_proj.weight', 'h.1.mlp.c_proj.bias', 'h.2.ln_1.weight', 'h.2.ln_1.bias', 'h.2.attn.c_attn.weight', 'h.2.attn.c_attn.bias', 'h.2.attn.c_proj.weight', 'h.2.attn.c_proj.bias', 'h.2.ln_2.weight', 'h.2.ln_2.bias', 'h.2.mlp.c_fc.weight', 'h.2.mlp.c_fc.bias', 'h.2.mlp.c_proj.weight', 'h.2.mlp.c_proj.bias', 'h.3.ln_1.weight', 'h.3.ln_1.bias', 'h.3.attn.c_attn.weight', 'h.3.attn.c_attn.bias', 'h.3.attn.c_proj.weight', 'h.3.attn.c_proj.bias', 'h.3.ln_2.weight', 'h.3.ln_2.bias', 'h.3.mlp.c_fc.weight', 'h.3.mlp.c_fc.bias', 'h.3.mlp.c_proj.weight', 'h.3.mlp.c_proj.bias', 'h.4.ln_1.weight', 'h.4.ln_1.bias', 'h.4.attn.c_attn.weight', 'h.4.attn.c_attn.bias', 'h.4.attn.c_proj.weight', 'h.4.attn.c_proj.bias', 'h.4.ln_2.weight', 'h.4.ln_2.bias', 'h.4.mlp.c_fc.weight', 'h.4.mlp.c_fc.bias', 'h.4.mlp.c_proj.weight', 'h.4.mlp.c_proj.bias', 'h.5.ln_1.weight', 'h.5.ln_1.bias', 'h.5.attn.c_attn.weight', 'h.5.attn.c_attn.bias', 'h.5.attn.c_proj.weight', 'h.5.attn.c_proj.bias', 'h.5.ln_2.weight', 'h.5.ln_2.bias', 'h.5.mlp.c_fc.weight', 'h.5.mlp.c_fc.bias', 'h.5.mlp.c_proj.weight', 'h.5.mlp.c_proj.bias', 'h.6.ln_1.weight', 'h.6.ln_1.bias', 'h.6.attn.c_attn.weight', 'h.6.attn.c_attn.bias', 'h.6.attn.c_proj.weight', 'h.6.attn.c_proj.bias', 'h.6.ln_2.weight', 'h.6.ln_2.bias', 'h.6.mlp.c_fc.weight', 'h.6.mlp.c_fc.bias', 'h.6.mlp.c_proj.weight', 'h.6.mlp.c_proj.bias', 'h.7.ln_1.weight', 'h.7.ln_1.bias', 'h.7.attn.c_attn.weight', 'h.7.attn.c_attn.bias', 'h.7.attn.c_proj.weight', 'h.7.attn.c_proj.bias', 'h.7.ln_2.weight', 'h.7.ln_2.bias', 'h.7.mlp.c_fc.weight', 'h.7.mlp.c_fc.bias', 'h.7.mlp.c_proj.weight', 'h.7.mlp.c_proj.bias', 'h.8.ln_1.weight', 'h.8.ln_1.bias', 'h.8.attn.c_attn.weight', 'h.8.attn.c_attn.bias', 'h.8.attn.c_proj.weight', 'h.8.attn.c_proj.bias', 'h.8.ln_2.weight', 'h.8.ln_2.bias', 'h.8.mlp.c_fc.weight', 'h.8.mlp.c_fc.bias', 'h.8.mlp.c_proj.weight', 'h.8.mlp.c_proj.bias', 'h.9.ln_1.weight', 'h.9.ln_1.bias', 'h.9.attn.c_attn.weight', 'h.9.attn.c_attn.bias', 'h.9.attn.c_proj.weight', 'h.9.attn.c_proj.bias', 'h.9.ln_2.weight', 'h.9.ln_2.bias', 'h.9.mlp.c_fc.weight', 'h.9.mlp.c_fc.bias', 'h.9.mlp.c_proj.weight', 'h.9.mlp.c_proj.bias', 'h.10.ln_1.weight', 'h.10.ln_1.bias', 'h.10.attn.c_attn.weight', 'h.10.attn.c_attn.bias', 'h.10.attn.c_proj.weight', 'h.10.attn.c_proj.bias', 'h.10.ln_2.weight', 'h.10.ln_2.bias', 'h.10.mlp.c_fc.weight', 'h.10.mlp.c_fc.bias', 'h.10.mlp.c_proj.weight', 'h.10.mlp.c_proj.bias', 'h.11.ln_1.weight', 'h.11.ln_1.bias', 'h.11.attn.c_attn.weight', 'h.11.attn.c_attn.bias', 'h.11.attn.c_proj.weight', 'h.11.attn.c_proj.bias', 'h.11.ln_2.weight', 'h.11.ln_2.bias', 'h.11.mlp.c_fc.weight', 'h.11.mlp.c_fc.bias', 'h.11.mlp.c_proj.weight', 'h.11.mlp.c_proj.bias', 'ln_f.weight', 'ln_f.bias']
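For context: model.named_parameters() only lists trainable parameters, while attn.bias and attn.masked_bias are registered buffers in the Hugging Face GPT-2 implementation (the constant causal mask), so they can only ever show up via state_dict(). A minimal sketch to see this, assuming a plain gpt2 base model (whether these buffers get persisted into state_dict() depends on the transformers version):

from transformers import GPT2Model

model = GPT2Model.from_pretrained("gpt2")

param_keys = {name for name, _ in model.named_parameters()}
state_keys = set(model.state_dict().keys())

# Buffers like h.0.attn.bias / h.0.attn.masked_bias are not parameters, so any
# difference here is exactly the buffer set the error message complains about
print(sorted(state_keys - param_keys))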

These are the keys of the trained RewardModel's state dict after loading the checkpoint with torch.load():

['v_head.weight', 'rwtransformer.wte.weight', 'rwtransformer.wpe.weight', 'rwtransformer.h.0.ln_1.weight', 'rwtransformer.h.0.ln_1.bias', 'rwtransformer.h.0.attn.c_attn.weight', 'rwtransformer.h.0.attn.c_attn.bias', 'rwtransformer.h.0.attn.c_proj.weight', 'rwtransformer.h.0.attn.c_proj.bias', 'rwtransformer.h.0.ln_2.weight', 'rwtransformer.h.0.ln_2.bias', 'rwtransformer.h.0.mlp.c_fc.weight', 'rwtransformer.h.0.mlp.c_fc.bias', 'rwtransformer.h.0.mlp.c_proj.weight', 'rwtransformer.h.0.mlp.c_proj.bias', 'rwtransformer.h.1.ln_1.weight', 'rwtransformer.h.1.ln_1.bias', 'rwtransformer.h.1.attn.c_attn.weight', 'rwtransformer.h.1.attn.c_attn.bias', 'rwtransformer.h.1.attn.c_proj.weight', 'rwtransformer.h.1.attn.c_proj.bias', 'rwtransformer.h.1.ln_2.weight', 'rwtransformer.h.1.ln_2.bias', 'rwtransformer.h.1.mlp.c_fc.weight', 'rwtransformer.h.1.mlp.c_fc.bias', 'rwtransformer.h.1.mlp.c_proj.weight', 'rwtransformer.h.1.mlp.c_proj.bias', 'rwtransformer.h.2.ln_1.weight', 'rwtransformer.h.2.ln_1.bias', 'rwtransformer.h.2.attn.c_attn.weight', 'rwtransformer.h.2.attn.c_attn.bias', 'rwtransformer.h.2.attn.c_proj.weight', 'rwtransformer.h.2.attn.c_proj.bias', 'rwtransformer.h.2.ln_2.weight', 'rwtransformer.h.2.ln_2.bias', 'rwtransformer.h.2.mlp.c_fc.weight', 'rwtransformer.h.2.mlp.c_fc.bias', 'rwtransformer.h.2.mlp.c_proj.weight', 'rwtransformer.h.2.mlp.c_proj.bias', 'rwtransformer.h.3.ln_1.weight', 'rwtransformer.h.3.ln_1.bias', 'rwtransformer.h.3.attn.c_attn.weight', 'rwtransformer.h.3.attn.c_attn.bias', 'rwtransformer.h.3.attn.c_proj.weight', 'rwtransformer.h.3.attn.c_proj.bias', 'rwtransformer.h.3.ln_2.weight', 'rwtransformer.h.3.ln_2.bias', 'rwtransformer.h.3.mlp.c_fc.weight', 'rwtransformer.h.3.mlp.c_fc.bias', 'rwtransformer.h.3.mlp.c_proj.weight', 'rwtransformer.h.3.mlp.c_proj.bias', 'rwtransformer.h.4.ln_1.weight', 'rwtransformer.h.4.ln_1.bias', 'rwtransformer.h.4.attn.c_attn.weight', 'rwtransformer.h.4.attn.c_attn.bias', 'rwtransformer.h.4.attn.c_proj.weight', 'rwtransformer.h.4.attn.c_proj.bias', 'rwtransformer.h.4.ln_2.weight', 'rwtransformer.h.4.ln_2.bias', 'rwtransformer.h.4.mlp.c_fc.weight', 'rwtransformer.h.4.mlp.c_fc.bias', 'rwtransformer.h.4.mlp.c_proj.weight', 'rwtransformer.h.4.mlp.c_proj.bias', 'rwtransformer.h.5.ln_1.weight', 'rwtransformer.h.5.ln_1.bias', 'rwtransformer.h.5.attn.c_attn.weight', 'rwtransformer.h.5.attn.c_attn.bias', 'rwtransformer.h.5.attn.c_proj.weight', 'rwtransformer.h.5.attn.c_proj.bias', 'rwtransformer.h.5.ln_2.weight', 'rwtransformer.h.5.ln_2.bias', 'rwtransformer.h.5.mlp.c_fc.weight', 'rwtransformer.h.5.mlp.c_fc.bias', 'rwtransformer.h.5.mlp.c_proj.weight', 'rwtransformer.h.5.mlp.c_proj.bias', 'rwtransformer.h.6.ln_1.weight', 'rwtransformer.h.6.ln_1.bias', 'rwtransformer.h.6.attn.c_attn.weight', 'rwtransformer.h.6.attn.c_attn.bias', 'rwtransformer.h.6.attn.c_proj.weight', 'rwtransformer.h.6.attn.c_proj.bias', 'rwtransformer.h.6.ln_2.weight', 'rwtransformer.h.6.ln_2.bias', 'rwtransformer.h.6.mlp.c_fc.weight', 'rwtransformer.h.6.mlp.c_fc.bias', 'rwtransformer.h.6.mlp.c_proj.weight', 'rwtransformer.h.6.mlp.c_proj.bias', 'rwtransformer.h.7.ln_1.weight', 'rwtransformer.h.7.ln_1.bias', 'rwtransformer.h.7.attn.c_attn.weight', 'rwtransformer.h.7.attn.c_attn.bias', 'rwtransformer.h.7.attn.c_proj.weight', 'rwtransformer.h.7.attn.c_proj.bias', 'rwtransformer.h.7.ln_2.weight', 'rwtransformer.h.7.ln_2.bias', 'rwtransformer.h.7.mlp.c_fc.weight', 'rwtransformer.h.7.mlp.c_fc.bias', 'rwtransformer.h.7.mlp.c_proj.weight', 'rwtransformer.h.7.mlp.c_proj.bias', 
'rwtransformer.h.8.ln_1.weight', 'rwtransformer.h.8.ln_1.bias', 'rwtransformer.h.8.attn.c_attn.weight', 'rwtransformer.h.8.attn.c_attn.bias', 'rwtransformer.h.8.attn.c_proj.weight', 'rwtransformer.h.8.attn.c_proj.bias', 'rwtransformer.h.8.ln_2.weight', 'rwtransformer.h.8.ln_2.bias', 'rwtransformer.h.8.mlp.c_fc.weight', 'rwtransformer.h.8.mlp.c_fc.bias', 'rwtransformer.h.8.mlp.c_proj.weight', 'rwtransformer.h.8.mlp.c_proj.bias', 'rwtransformer.h.9.ln_1.weight', 'rwtransformer.h.9.ln_1.bias', 'rwtransformer.h.9.attn.c_attn.weight', 'rwtransformer.h.9.attn.c_attn.bias', 'rwtransformer.h.9.attn.c_proj.weight', 'rwtransformer.h.9.attn.c_proj.bias', 'rwtransformer.h.9.ln_2.weight', 'rwtransformer.h.9.ln_2.bias', 'rwtransformer.h.9.mlp.c_fc.weight', 'rwtransformer.h.9.mlp.c_fc.bias', 'rwtransformer.h.9.mlp.c_proj.weight', 'rwtransformer.h.9.mlp.c_proj.bias', 'rwtransformer.h.10.ln_1.weight', 'rwtransformer.h.10.ln_1.bias', 'rwtransformer.h.10.attn.c_attn.weight', 'rwtransformer.h.10.attn.c_attn.bias', 'rwtransformer.h.10.attn.c_proj.weight', 'rwtransformer.h.10.attn.c_proj.bias', 'rwtransformer.h.10.ln_2.weight', 'rwtransformer.h.10.ln_2.bias', 'rwtransformer.h.10.mlp.c_fc.weight', 'rwtransformer.h.10.mlp.c_fc.bias', 'rwtransformer.h.10.mlp.c_proj.weight', 'rwtransformer.h.10.mlp.c_proj.bias', 'rwtransformer.h.11.ln_1.weight', 'rwtransformer.h.11.ln_1.bias', 'rwtransformer.h.11.attn.c_attn.weight', 'rwtransformer.h.11.attn.c_attn.bias', 'rwtransformer.h.11.attn.c_proj.weight', 'rwtransformer.h.11.attn.c_proj.bias', 'rwtransformer.h.11.ln_2.weight', 'rwtransformer.h.11.ln_2.bias', 'rwtransformer.h.11.mlp.c_fc.weight', 'rwtransformer.h.11.mlp.c_fc.bias', 'rwtransformer.h.11.mlp.c_proj.weight', 'rwtransformer.h.11.mlp.c_proj.bias', 'rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias']

As you can see, the only difference between the model before and after reward model training (after removing the 'rwtransformer.' prefix) is this:

# model_params: keys before step 2; model2_params: checkpoint keys with the
# 'rwtransformer.' prefix stripped
list(set(model2_params) - set(model_params))
# -> ['v_head.weight']
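For reference, a self-contained version of that comparison (a sketch: gpt2 stands in for whatever base model you used, and pytorch_model.bin for the step-2 output):

import torch
from transformers import GPT2Model

base_keys = {name for name, _ in GPT2Model.from_pretrained("gpt2").named_parameters()}

ckpt = torch.load("pytorch_model.bin", map_location="cpu")  # step-2 RewardModel checkpoint
ckpt_keys = {key.removeprefix("rwtransformer.") for key in ckpt}  # strip the wrapper prefix

print(sorted(ckpt_keys - base_keys))  # -> ['v_head.weight']
print(sorted(base_keys - ckpt_keys))  # -> [] (no attn.bias buffers on either side)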

It looks like, somewhere along the way, these bias entries are being added to the model's expected state dict even though they were never saved in the checkpoint (I am not using LoRA for this run).
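A possible workaround (my assumption, not a verified fix): since attn.bias / attn.masked_bias are constant causal-mask buffers that GPT-2 re-creates at initialization anyway, loading the checkpoint non-strictly and asserting that nothing else is missing should be safe. A sketch against the critic_model built in create_critic_model (ckpt_path is a placeholder):

import torch

state_dict = torch.load(ckpt_path, map_location="cpu")  # placeholder path to the step-2 checkpoint
result = critic_model.load_state_dict(state_dict, strict=False)

# Tolerate only the constant attention-mask buffers; fail on anything else
truly_missing = [k for k in result.missing_keys
                 if not k.endswith((".attn.bias", ".attn.masked_bias"))]
assert not truly_missing, f"Missing keys: {truly_missing}"
assert not result.unexpected_keys, f"Unexpected keys: {result.unexpected_keys}"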

Metadata

Labels

bug (Something isn't working), deespeed chat (DeepSpeed Chat)
