Skip to content

Commit bd3c0c6

Browse files
authored
Fix loading of optimizer with older DQN models (#1978)
1 parent 000544c commit bd3c0c6

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

docs/misc/changelog.rst

+10-1
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,16 @@
33
Changelog
44
==========
55

6-
Release 2.4.0a6 (WIP)
6+
Release 2.4.0a7 (WIP)
77
--------------------------
88

9+
.. note::
10+
11+
DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
12+
truncation of optimizer state when loaded with SB3 >= 2.4.0.
13+
To suppress the warning, simply save the model again.
14+
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_
15+
916
Breaking Changes:
1017
^^^^^^^^^^^^^^^^^
1118

@@ -28,9 +35,11 @@ Bug Fixes:
2835

2936
`RL Zoo`_
3037
^^^^^^^^^
38+
- Updated default hyperparameters for TQC/SAC for Swimmer-v4 (decreased gamma for more consistent results)
3139

3240
`SBX`_ (SB3 + Jax)
3341
^^^^^^^^^^^^^^^^^^
42+
- Added CNN support for DQN
3443

3544
Deprecations:
3645
^^^^^^^^^^^^^

stable_baselines3/common/base_class.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -742,13 +742,13 @@ def load( # noqa: C901
742742
# put state_dicts back in place
743743
model.set_parameters(params, exact_match=True, device=device)
744744
except RuntimeError as e:
745-
# Patch to load Policy saved using SB3 < 1.7.0
745+
# Patch to load policies saved using SB3 < 1.7.0
746746
# the error is probably due to an old policy being loaded
747747
# See https://github.com/DLR-RM/stable-baselines3/issues/1233
748748
if "pi_features_extractor" in str(e) and "Missing key(s) in state_dict" in str(e):
749749
model.set_parameters(params, exact_match=False, device=device)
750750
warnings.warn(
751-
"You are probably loading a model saved with SB3 < 1.7.0, "
751+
"You are probably loading a A2C/PPO model saved with SB3 < 1.7.0, "
752752
"we deactivated exact_match so you can save the model "
753753
"again to avoid issues in the future "
754754
"(see https://github.com/DLR-RM/stable-baselines3/issues/1233 for more info). "
@@ -757,6 +757,29 @@ def load( # noqa: C901
757757
)
758758
else:
759759
raise e
760+
except ValueError as e:
761+
# Patch to load DQN policies saved using SB3 < 2.4.0
762+
# The target network params are no longer in the optimizer
763+
# See https://github.com/DLR-RM/stable-baselines3/pull/1963
764+
saved_optim_params = params["policy.optimizer"]["param_groups"][0]["params"] # type: ignore[index]
765+
n_params_saved = len(saved_optim_params)
766+
n_params = len(model.policy.optimizer.param_groups[0]["params"])
767+
if n_params_saved == 2 * n_params:
768+
# Truncate to include only online network params
769+
params["policy.optimizer"]["param_groups"][0]["params"] = saved_optim_params[:n_params] # type: ignore[index]
770+
771+
model.set_parameters(params, exact_match=True, device=device)
772+
warnings.warn(
773+
"You are probably loading a DQN model saved with SB3 < 2.4.0, "
774+
"we truncated the optimizer state so you can save the model "
775+
"again to avoid issues in the future "
776+
"(see https://github.com/DLR-RM/stable-baselines3/pull/1963 for more info). "
777+
f"Original error: {e} \n"
778+
"Note: the model should still work fine, this only a warning."
779+
)
780+
else:
781+
raise e
782+
760783
# put other pytorch variables back in place
761784
if pytorch_variables is not None:
762785
for name in pytorch_variables:

stable_baselines3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a6
1+
2.4.0a7

tests/test_save_load.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ def test_save_load_env_cnn(tmp_path, model_class):
340340
# clear file from os
341341
os.remove(tmp_path / "test_save.zip")
342342

343-
# Check we can load models saved with SB3 < 1.7.0
343+
# Check we can load A2C/PPO models saved with SB3 < 1.7.0
344344
if model_class == A2C:
345345
del model.policy.pi_features_extractor
346346
model.save(tmp_path / "test_save")
@@ -809,3 +809,15 @@ def test_save_load_net_arch_none(tmp_path):
809809
# None has been replaced by the default net arch
810810
assert model.policy.net_arch is not None
811811
os.remove(tmp_path / "ppo.zip")
812+
813+
814+
def test_save_load_no_target_params(tmp_path):
815+
# Check we can load DQN models saved with SB3 < 2.4.0
816+
model = DQN("MlpPolicy", "CartPole-v1", buffer_size=10000, learning_starts=4)
817+
env = model.get_env()
818+
# Include target net params
819+
model.policy.optimizer = th.optim.Adam(model.policy.parameters(), lr=0.001)
820+
model.save(tmp_path / "test_save")
821+
with pytest.warns(UserWarning):
822+
DQN.load(str(tmp_path / "test_save.zip"), env=env).learn(20)
823+
os.remove(tmp_path / "test_save.zip")

0 commit comments

Comments
 (0)