Size mismatch when loading pre-trained models  #37

@Rich2333

Hi,

When I try to load the pre-trained models to test predict.py, I get the following error:

python predict.py pre-trained/final-energy-per-atom.pth.tar mp/
=> loading model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loaded model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loading model 'pre-trained/final-energy-per-atom.pth.tar'
Traceback (most recent call last):
File "E:\cgcnn-master\predict.py", line 298, in <module>
main()
File "E:\cgcnn-master\predict.py", line 94, in main
model.load_state_dict(checkpoint['state_dict'])
File "C:\ProgramData\Anaconda3\envs\cgcnn1\lib\site-packages\torch\nn\modules\module.py", line 1497, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CrystalGraphConvNet:
size mismatch for convs.0.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.1.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.2.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
size mismatch for convs.3.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
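
To narrow down a mismatch like this, it can help to list every parameter whose shape differs between the checkpoint and the current model. The following is a minimal sketch, not part of cgcnn: `find_shape_mismatches` is a hypothetical helper, and the shapes are copied by hand from the error message above. With PyTorch, the two dictionaries could instead be built from `checkpoint['state_dict']` and `model.state_dict()`, as noted in the comments.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters
    present in both dictionaries whose shapes differ."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# Shapes copied from the RuntimeError above. With PyTorch they could be
# collected as, for example:
#   {k: tuple(v.shape) for k, v in checkpoint['state_dict'].items()}
#   {k: tuple(v.shape) for k, v in model.state_dict().items()}
checkpoint_shapes = {f"convs.{i}.fc_full.weight": (128, 169) for i in range(4)}
model_shapes = {f"convs.{i}.fc_full.weight": (128, 179) for i in range(4)}

for name, (ckpt, cur) in find_shape_mismatches(checkpoint_shapes, model_shapes).items():
    # Per-dimension difference: current model minus checkpoint
    diff = [c - k for k, c in zip(ckpt, cur)]
    print(name, "checkpoint:", ckpt, "model:", cur, "difference:", diff)
```

For the shapes in this report, only the second dimension (the input size of each `fc_full` layer) differs, by 10 in every convolution layer. That suggests the input features of the current model are built differently from those used for the pre-trained one, though the cause is not confirmed here.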

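By the way, I then trained my own model and used it for prediction. The errors above did not appear, but the MAE is far too large (about 6.0, with a loss of inf, even though the checkpoint reports a validation value of about 0.059):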

(cgcnn) E:\cgcnn-master>python predict.py E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar mp/
=> loading model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loading model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar' (epoch 484, validation 0.05862389877438545)
C:\ProgramData\Anaconda3\envs\cgcnn\lib\site-packages\pymatgen\io\cif.py:1155: UserWarning: Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings))
Test: [0/74] Time 26.633 (26.633) Loss inf (inf) MAE 5.977 (5.977)
Test: [10/74] Time 24.787 (27.052) Loss inf (inf) MAE 6.005 (6.013)
Test: [20/74] Time 28.383 (28.096) Loss inf (inf) MAE 5.941 (6.010)
Test: [30/74] Time 31.305 (28.518) Loss inf (inf) MAE 6.081 (6.008)
Test: [40/74] Time 30.491 (29.037) Loss inf (inf) MAE 5.860 (6.010)
Test: [50/74] Time 35.822 (29.651) Loss inf (inf) MAE 6.035 (6.008)
Test: [60/74] Time 33.488 (30.191) Loss inf (inf) MAE 6.033 (6.012)
Test: [70/74] Time 34.823 (30.565) Loss inf (inf) MAE 5.955 (6.008)
** MAE 6.009

Thanks for your attention!
