RuntimeError: Error(s) in loading state_dict: Unexpected key(s) when recovering results from main process during Trainer.fit() #246
Description
I am trying to get multi-GPU training working for hyperparameter tuning with Ray Tune via ray_lightning. However, I get the following error:
File "/home/davina/Private/repos/autopopulus/autopopulus/models/ap.py", line 182, in _fit
self.trainer.fit(self.ae, datamodule=data)
File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 838, in fit
self._call_and_handle_interrupt(
File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 783, in _call_and_handle_interrupt
return self.strategy.launcher.launch(
File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 73, in launch
self._recover_results_in_main_process(ray_output, trainer)
File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/ray_lightning/launchers/ray_launcher.py", line 399, in _recover_results_in_main_process
trainer.lightning_module.load_state_dict(state_dict)
File "/home/davina/mambaforge/envs/ap/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for AEDitto:
Unexpected key(s) in state_dict: "fc_mu.weight", "fc_mu.bias", "fc_var.weight", "fc_var.bias", "decoder.0.weight", "decoder.0.bias".
Looking at the following lines in RayLauncher._recover_results_in_main_process():
if ray_output.weights_path is not None:
    state_stream = ray_output.weights_path
    # DDPSpawnPlugin.__recover_child_process_weights begin
    # Difference here is that instead of writing the model weights to a
    # file and loading it, we use the state dict of the model directly.
    state_dict = load_state_stream(state_stream, to_gpu=self._strategy.use_gpu)
    # Set the state for PTL using the output from remote training.
    trainer.lightning_module.load_state_dict(state_dict)
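To see which keys disagree before the strict load_state_dict call above raises, the two key sets can be compared directly. The helper below is a hypothetical, plain-Python sketch (diff_state_dict_keys is not part of ray_lightning or PyTorch). PyTorch reports the same information itself when load_state_dict is called with strict=False, through the missing_keys and unexpected_keys fields of its return value.

```python
from collections import OrderedDict
from typing import List, Mapping, Tuple


def diff_state_dict_keys(
    loaded: Mapping[str, object], expected: Mapping[str, object]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected) key lists.

    missing    -- keys the module expects but the loaded dict lacks
    unexpected -- keys in the loaded dict that the module does not have
    """
    missing = sorted(k for k in expected if k not in loaded)
    unexpected = sorted(k for k in loaded if k not in expected)
    return missing, unexpected


# Mirrors the report: the recovered dict holds weights, while the
# main-process module's state_dict() is an empty OrderedDict().
recovered = OrderedDict.fromkeys(["fc_mu.weight", "fc_mu.bias", "decoder.0.weight"])
module_state = OrderedDict()

missing, unexpected = diff_state_dict_keys(recovered, module_state)
print("missing:", missing)
print("unexpected:", unexpected)
```

With an empty module state, every recovered key shows up as unexpected and nothing is missing, which is exactly the pattern in the error message.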
If I compare state_dict with trainer.lightning_module.state_dict(), the latter is completely empty: it just outputs OrderedDict(). The former contains all the weights listed in the error, with actual data. So for some reason the LightningModule in the main process is not being set up (or something similar), which leaves it with no state. This is not an issue when I use plain Ray Tune with one GPU per trial, without ray_lightning.
For reference, see how I'm running tuning.
Other info:
ray-core 2.2.0 py39h4d85f9a_1 conda-forge
ray-dashboard 2.2.0 py39h9a2ef2b_1 conda-forge
ray-default 2.2.0 py39hf3d152e_1 conda-forge
ray-lightning 0.3.0 pypi_0 pypi
ray-tune 2.2.0 py39hf3d152e_1 conda-forge
Python 3.9.15
OS: Ubuntu 18.04.4 LTS (Bionic Beaver)
Other relevant information:
cudatoolkit=10.2
pytorch=1.12.1
pytorch-lightning=1.6.5
cudnn=7.6.5
Specs:
4 GeForce RTX 2080 Ti GPUs
32 CPUs (x86_64)