Skip to content

Commit 506bb7a

Browse files
authored
Add support for python object in python config for wrapper/callbacks (#479)
1 parent 633954f commit 506bb7a

File tree

9 files changed

+58
-16
lines changed

9 files changed

+58
-16
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
## Release 2.5.0a0 (WIP)
1+
## Release 2.5.0a1 (WIP)
22

33
### Breaking Changes
44
- Upgraded to Pytorch >= 2.3.0
55
- Upgraded to SB3 >= 2.5.0
66

77
### New Features
88
- Added support for Numpy v2
9+
- Added support for specifying callbacks and env wrapper as python object in python config files (instead of string)
910

1011
### Bug fixes
1112

rl_zoo3/enjoy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def enjoy() -> None: # noqa: C901
204204
obs = env.reset()
205205

206206
# Deterministic by default except for atari games
207-
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
207+
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
208208
deterministic = not stochastic
209209

210210
episode_reward = 0.0

rl_zoo3/exp_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,8 +602,7 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
602602
if (
603603
"Neck" in self.env_name.gym_id
604604
or self.is_robotics_env(self.env_name.gym_id)
605-
or "parking-v0" in self.env_name.gym_id
606-
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
605+
or ("parking-v0" in self.env_name.gym_id and len(self.monitor_kwargs) == 0) # do not overwrite custom kwargs
607606
):
608607
self.monitor_kwargs = dict(info_keywords=("is_success",))
609608

rl_zoo3/push_to_hub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def package_to_hub(
398398
model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs)
399399

400400
# Deterministic by default except for atari games
401-
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
401+
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
402402
deterministic = not stochastic
403403

404404
# Default model name, the model will be saved under "{algo}-{env_name}.zip"

rl_zoo3/record_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@
131131
model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)
132132

133133
# Deterministic by default except for atari games
134-
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
134+
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
135135
deterministic = not stochastic
136136

137137
if video_folder is None:

rl_zoo3/utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,22 @@ def get_class_name(wrapper_name):
9999
kwargs = wrapper_dict[wrapper_name]
100100
else:
101101
kwargs = {}
102-
wrapper_module = importlib.import_module(get_module_name(wrapper_name))
103-
wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))
102+
103+
if isinstance(wrapper_name, str):
104+
wrapper_module = importlib.import_module(get_module_name(wrapper_name))
105+
wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))
106+
elif isinstance(wrapper_name, type):
107+
# No conversion needed
108+
wrapper_class = wrapper_name
109+
else:
110+
raise ValueError(
111+
f"Unexpected value {wrapper_name} for a {key}, must a str and a class, not {type(wrapper_name)}"
112+
)
113+
104114
wrapper_classes.append(wrapper_class)
105115
wrapper_kwargs.append(kwargs)
106116

107117
def wrap_env(env: gym.Env) -> gym.Env:
108-
"""
109-
:param env:
110-
:return:
111-
"""
112118
for wrapper_class, kwargs in zip(wrapper_classes, wrapper_kwargs):
113119
env = wrapper_class(env, **kwargs)
114120
return env
@@ -183,8 +189,12 @@ def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
183189
else:
184190
kwargs = {}
185191

186-
callback_class = get_class_by_name(callback_name)
187-
callbacks.append(callback_class(**kwargs))
192+
if isinstance(callback_name, BaseCallback):
193+
# No conversion needed
194+
callbacks.append(callback_name)
195+
else:
196+
callback_class = get_class_by_name(callback_name)
197+
callbacks.append(callback_class(**kwargs))
188198

189199
return callbacks
190200

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.5.0a0
1+
2.5.0a1

tests/test_callbacks.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import shlex
22
import subprocess
33

4+
import pytest
5+
import stable_baselines3 as sb3
6+
7+
from rl_zoo3.utils import get_callback_list
8+
49

510
def _assert_eq(left, right):
611
assert left == right, f"{left} != {right}"
@@ -13,3 +18,26 @@ def test_raw_stat_callback(tmp_path):
1318
)
1419
return_code = subprocess.call(shlex.split(cmd))
1520
_assert_eq(return_code, 0)
21+
22+
23+
@pytest.mark.parametrize(
24+
"callback",
25+
[
26+
None,
27+
"rl_zoo3.callbacks.RawStatisticsCallback",
28+
[
29+
{"stable_baselines3.common.callbacks.StopTrainingOnMaxEpisodes": dict(max_episodes=3)},
30+
"rl_zoo3.callbacks.RawStatisticsCallback",
31+
],
32+
[sb3.common.callbacks.StopTrainingOnMaxEpisodes(3)],
33+
],
34+
)
35+
def test_get_callback(callback):
36+
hyperparams = {"callback": callback}
37+
callback_list = get_callback_list(hyperparams)
38+
if callback is None:
39+
assert len(callback_list) == 0
40+
elif isinstance(callback, str):
41+
assert len(callback_list) == 1
42+
else:
43+
assert len(callback_list) == len(callback)

tests/test_wrappers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import gymnasium as gym
22
import pytest
3+
import stable_baselines3 as sb3
34
from stable_baselines3 import A2C
45
from stable_baselines3.common.env_checker import check_env
56
from stable_baselines3.common.env_util import DummyVecEnv
67

7-
import rl_zoo3.import_envs # noqa: F401
8+
import rl_zoo3.import_envs
9+
import rl_zoo3.wrappers
810
from rl_zoo3.utils import get_wrapper_class
911
from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper
1012

@@ -24,6 +26,7 @@ def test_wrappers():
2426
None,
2527
{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=2)},
2628
[{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"],
29+
[{rl_zoo3.wrappers.HistoryWrapper: dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"],
2730
],
2831
)
2932
def test_get_wrapper(env_wrapper):
@@ -40,6 +43,7 @@ def test_get_wrapper(env_wrapper):
4043
[
4144
None,
4245
{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)},
46+
{sb3.common.vec_env.VecFrameStack: dict(n_stack=2)},
4347
[{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=3)}, "stable_baselines3.common.vec_env.VecMonitor"],
4448
],
4549
)

0 commit comments

Comments
 (0)