@@ -742,13 +742,13 @@ def load( # noqa: C901
742
742
# put state_dicts back in place
743
743
model .set_parameters (params , exact_match = True , device = device )
744
744
except RuntimeError as e :
745
- # Patch to load Policy saved using SB3 < 1.7.0
745
+ # Patch to load policies saved using SB3 < 1.7.0
746
746
# the error is probably due to an old policy being loaded
747
747
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
748
748
if "pi_features_extractor" in str (e ) and "Missing key(s) in state_dict" in str (e ):
749
749
model .set_parameters (params , exact_match = False , device = device )
750
750
warnings .warn (
751
- "You are probably loading a model saved with SB3 < 1.7.0, "
751
+ "You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
752
752
"we deactivated exact_match so you can save the model "
753
753
"again to avoid issues in the future "
754
754
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
@@ -757,6 +757,29 @@ def load( # noqa: C901
757
757
)
758
758
else :
759
759
raise e
760
+ except ValueError as e :
761
+ # Patch to load DQN policies saved using SB3 < 2.4.0
762
+ # The target network params are no longer in the optimizer
763
+ # See https://github.com/DLR-RM/stable-baselines3/pull/1963
764
+ saved_optim_params = params ["policy.optimizer" ]["param_groups" ][0 ]["params" ] # type: ignore[index]
765
+ n_params_saved = len (saved_optim_params )
766
+ n_params = len (model .policy .optimizer .param_groups [0 ]["params" ])
767
+ if n_params_saved == 2 * n_params :
768
+ # Truncate to include only online network params
769
+ params ["policy.optimizer" ]["param_groups" ][0 ]["params" ] = saved_optim_params [:n_params ] # type: ignore[index]
770
+
771
+ model .set_parameters (params , exact_match = True , device = device )
772
+ warnings .warn (
773
+ "You are probably loading a DQN model saved with SB3 < 2.4.0, "
774
+ "we truncated the optimizer state so you can save the model "
775
+ "again to avoid issues in the future "
776
+ "(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
777
+ f"Original error: { e } \n "
778
+ "Note: the model should still work fine, this only a warning."
779
+ )
780
+ else :
781
+ raise e
782
+
760
783
# put other pytorch variables back in place
761
784
if pytorch_variables is not None :
762
785
for name in pytorch_variables :
0 commit comments