This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

RuntimeError: Error(s) in loading state_dict: Unexpected key(s) when recovering results from main process during Trainer.fit() #246

Open
@davzaman

Description

I am trying to get multi-GPU training working for hyperparameter tuning with Ray Tune, but I am getting the following error:

  File "/home/davina/Private/repos/autopopulus/autopopulus/models/ap.py", line 182, in _fit
    self.trainer.fit(self.ae, datamodule=data)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 838, in fit
    self._call_and_handle_interrupt(
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 783, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 73, in launch
    self._recover_results_in_main_process(ray_output, trainer)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 399, in _recover_results_in_main_process
    trainer.lightning_module.load_state_dict(state_dict)
  File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AEDitto:
        Unexpected key(s) in state_dict: "fc_mu.weight", "fc_mu.bias", "fc_var.weight", "fc_var.bias", "decoder.0.weight", "decoder.0.bias".

Looking at the following lines in RayLauncher._recover_results_in_main_process():

if ray_output.weights_path is not None:
    state_stream = ray_output.weights_path
    # DDPSpawnPlugin.__recover_child_process_weights begin
    # Difference here is that instead of writing the model weights to a
    # file and loading it, we use the state dict of the model directly.
    state_dict = load_state_stream(state_stream, to_gpu=self._strategy.use_gpu)
    # Set the state for PTL using the output from remote training.
    trainer.lightning_module.load_state_dict(state_dict)

If I probe state_dict versus trainer.lightning_module.state_dict(), the latter is completely empty (it just returns OrderedDict()), while the former contains all the weights listed in the error, with actual data. So for some reason the main-process LightningModule is never being set up, leaving it with no state. This is not an issue when I skip ray_lightning and run plain Ray Tune with one GPU per trial.
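The key mismatch can be reproduced in miniature without Ray or PyTorch: a strict loader raises on keys the destination module doesn't declare (which is exactly what happens here, since the main-process module declares none), while a non-strict loader merely reports them, which is what PyTorch's load_state_dict(strict=False) does. This is a pure-Python stand-in to illustrate the failure mode, not the actual ray_lightning or torch code path:

```python
from collections import OrderedDict

def load_state_dict(module_keys, state_dict, strict=True):
    """Tiny stand-in for the key checking in torch.nn.Module.load_state_dict."""
    unexpected = [k for k in state_dict if k not in module_keys]
    missing = [k for k in module_keys if k not in state_dict]
    if strict and (unexpected or missing):
        raise RuntimeError(
            "Error(s) in loading state_dict:\n\tUnexpected key(s): "
            + ", ".join(unexpected)
        )
    return missing, unexpected

# Mirrors the bug: the main-process module reports an empty state dict...
local_keys = OrderedDict()
# ...while the child-process weights carry the real parameters (values
# stand in for tensors).
remote = OrderedDict([("fc_mu.weight", 1), ("fc_mu.bias", 2)])

try:
    load_state_dict(local_keys, remote)  # strict: raises, as in the traceback
except RuntimeError as e:
    print(e)

missing, unexpected = load_state_dict(local_keys, remote, strict=False)
print(unexpected)  # ['fc_mu.weight', 'fc_mu.bias']
```

Loading with strict=False would silence the error but would not fix anything, since the empty main-process module still has no parameters to receive the weights; the real question is why the module is never set up before recovery runs.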

For reference, here is how I'm running tuning.

Other info:

ray-core                  2.2.0            py39h4d85f9a_1    conda-forge
ray-dashboard             2.2.0            py39h9a2ef2b_1    conda-forge
ray-default               2.2.0            py39hf3d152e_1    conda-forge
ray-lightning             0.3.0                    pypi_0    pypi
ray-tune                  2.2.0            py39hf3d152e_1    conda-forge

Python 3.9.15

OS: Ubuntu 18.04.4 LTS (Bionic Beaver)

Other relevant information:
cudatoolkit=10.2
pytorch=1.12.1
pytorch-lightning=1.6.5
cudnn=7.6.5

Specs:
4 GeForce RTX 2080 Ti's
32 CPUs (x86_64)
