diff --git a/recognition/arcface_torch/train.py b/recognition/arcface_torch/train.py index 9872783ab..2065c2e90 100755 --- a/recognition/arcface_torch/train.py +++ b/recognition/arcface_torch/train.py @@ -124,6 +124,7 @@ def main(args): backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) opt.load_state_dict(dict_checkpoint["state_optimizer"]) + dict_checkpoint['state_lr_scheduler']['max_steps'] = cfg.total_step lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) del dict_checkpoint