Commit 06ab062

Add support for special VecEnv (brax, IsaacSim, ...) (#484)
* Allow to change default VecEnv
* Add default vec env cls argument
* Fix normalization loading for objects
* Save exact command used when training, update changelog
* Update HF api usage
* Fix log-interval default behavior and upgrade to gym v1.1
* Add HF token to CI
1 parent 2e99bec commit 06ab062

10 files changed: +69 -15 lines


.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ jobs:
     env:
       TERM: xterm-256color
       FORCE_COLOR: 1
+      HF_TOKEN: ${{ secrets.HF_TOKEN }}
     # Skip CI if [ci skip] in the commit message
     if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
     runs-on: ubuntu-latest
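Side note on why a plain env entry is enough: `huggingface_hub` reads the `HF_TOKEN` environment variable when no token is passed explicitly, so the test suite needs no separate login step. A minimal sketch:

import os
from huggingface_hub import HfApi

# Passing the token explicitly here only to make the mechanism visible;
# HfApi() with no argument would pick up HF_TOKEN on its own.
api = HfApi(token=os.environ.get("HF_TOKEN"))
print(api.whoami()["name"])  # fails without a valid token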

.github/workflows/trained_agents.yml

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ jobs:
     env:
       TERM: xterm-256color
       FORCE_COLOR: 1
-
+      HF_TOKEN: ${{ secrets.HF_TOKEN }}
     # Skip CI if [ci skip] in the commit message
     if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')"
     runs-on: ubuntu-latest

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
@@ -1,3 +1,21 @@
+## Release 2.6.0a2 (WIP)
+
+### Breaking Changes
+- Upgraded to SB3 >= 2.6.0
+
+### New Features
+- Save the exact command line used to launch a training
+- Added support for special vectorized envs (e.g. Brax, IsaacSim) by allowing the `VecEnv` class used to instantiate the env to be overridden in the `ExperimentManager`
+- Allow disabling auto-logging by passing `--log-interval -2` (useful when logging things manually)
+- Added Gymnasium v1.1 support
+
+### Bug fixes
+- Fixed use of old HF api in `get_hf_trained_models()`
+
+### Documentation
+
+### Other
+
 ## Release 2.5.0 (2025-01-27)
 
 ### Breaking Changes

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 gym==0.26.2
-stable-baselines3[extra,tests,docs]>=2.5.0,<3.0
+stable-baselines3[extra,tests,docs]>=2.6.0a2,<3.0
 box2d-py==2.3.8
-pybullet_envs_gymnasium>=0.5.0
+pybullet_envs_gymnasium>=0.6.0
 # minigrid
 cloudpickle>=2.2.1
 # optuna plots:

rl_zoo3/enjoy.py

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ def enjoy() -> None:  # noqa: C901
         should_render=not args.no_render,
         hyperparams=hyperparams,
         env_kwargs=env_kwargs,
+        vec_env_cls=ExperimentManager.default_vec_env_cls,
     )

     kwargs = dict(seed=args.seed)

rl_zoo3/exp_manager.py

Lines changed: 20 additions & 0 deletions
@@ -2,6 +2,7 @@
 import importlib
 import os
 import pickle as pkl
+import sys
 import time
 import warnings
 from collections import OrderedDict
@@ -60,6 +61,9 @@ class ExperimentManager:
     Please take a look at `train.py` to have the details for each argument.
     """

+    # For special VecEnv like Brax, IsaacLab, ...
+    default_vec_env_cls: Optional[type[VecEnv]] = None
+
     def __init__(
         self,
         args: argparse.Namespace,
@@ -122,6 +126,10 @@ def __init__(
         self.optimization_log_path = optimization_log_path

         self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
+        # Override the default VecEnv class if one was set
+        if self.default_vec_env_cls is not None:
+            self.vec_env_class = self.default_vec_env_cls
+
         self.vec_env_wrapper: Optional[Callable] = None

         self.vec_env_kwargs: dict[str, Any] = {}
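A minimal usage sketch of the new hook, written against this commit; `BraxVecEnv` is a hypothetical wrapper that implements the SB3 `VecEnv` interface and is not part of rl_zoo3:

from rl_zoo3 import ExperimentManager
# Hypothetical Brax wrapper exposing the SB3 VecEnv API (illustrative import)
from my_project.wrappers import BraxVecEnv

# Every env created by the manager will now be instantiated with
# BraxVecEnv instead of DummyVecEnv/SubprocVecEnv:
ExperimentManager.default_vec_env_cls = BraxVecEnv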
@@ -224,8 +232,13 @@ def learn(self, model: BaseAlgorithm) -> None:
         :param model: an initialized RL model
         """
         kwargs: dict[str, Any] = {}
+        # log_interval == -1 -> use the SB3 default
+        # log_interval <= -2 -> no auto-logging
         if self.log_interval > -1:
             kwargs = {"log_interval": self.log_interval}
+        elif self.log_interval < -1:
+            # Deactivate auto-logging, helpful when using a callback like LogEveryNTimesteps
+            kwargs = {"log_interval": None}

         if len(self.callbacks) > 0:
             kwargs["callback"] = self.callbacks
@@ -288,6 +301,13 @@ def _save_config(self, saved_hyperparams: dict[str, Any]) -> None:
         ordered_args = OrderedDict([(key, vars(self.args)[key]) for key in sorted(vars(self.args).keys())])
         yaml.dump(ordered_args, f)

+        # Save the command used to launch the training
+        command = "python3 " + " ".join(sys.argv)
+        # Python 3.10+: sys.orig_argv also includes the interpreter and its flags
+        if hasattr(sys, "orig_argv"):
+            command = " ".join(sys.orig_argv)
+        (Path(self.params_path) / "command.txt").write_text(command)
+
         print(f"Log path: {self.save_path}")

     def read_hyperparameters(self) -> tuple[dict[str, Any], dict[str, Any]]:
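For reference, a small illustration of the two fallbacks (behavior per the Python docs):

import sys

# For an invocation like `python3 -O train.py --algo ppo`:
#   sys.argv      -> ["train.py", "--algo", "ppo"]                    (script args only)
#   sys.orig_argv -> ["python3", "-O", "train.py", "--algo", "ppo"]   (full command line)
command = " ".join(sys.orig_argv) if hasattr(sys, "orig_argv") else "python3 " + " ".join(sys.argv)
print(command)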

rl_zoo3/train.py

Lines changed: 6 additions & 1 deletion
@@ -32,7 +32,12 @@ def train() -> None:
     )
     parser.add_argument("-n", "--n-timesteps", help="Overwrite the number of timesteps", default=-1, type=int)
     parser.add_argument("--num-threads", help="Number of threads for PyTorch (-1 to use default)", default=-1, type=int)
-    parser.add_argument("--log-interval", help="Override log interval (default: -1, no change)", default=-1, type=int)
+    parser.add_argument(
+        "--log-interval",
+        help="Override log interval (default: -1, no change; -2: disable auto-logging, useful with a custom logging frequency)",
+        default=-1,
+        type=int,
+    )
     parser.add_argument(
         "--eval-freq",
         help="Evaluate the agent every n steps (if negative, no evaluation). "

rl_zoo3/utils.py

Lines changed: 17 additions & 8 deletions
@@ -208,6 +208,8 @@ def create_test_env(
     should_render: bool = True,
     hyperparams: Optional[dict[str, Any]] = None,
     env_kwargs: Optional[dict[str, Any]] = None,
+    vec_env_cls: Optional[type[VecEnv]] = None,
+    vec_env_kwargs: Optional[dict[str, Any]] = None,
 ) -> VecEnv:
     """
     Create environment for testing a trained agent
@@ -220,6 +222,8 @@ def create_test_env(
     :param should_render: For Pybullet env, display the GUI
     :param hyperparams: Additional hyperparams (ex: n_stack)
     :param env_kwargs: Optional keyword argument to pass to the env constructor
+    :param vec_env_cls: ``VecEnv`` class constructor.
+    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
     :return:
     """
     # Create the environment and wrap it if necessary
@@ -231,9 +235,9 @@ def create_test_env(
     if "env_wrapper" in hyperparams.keys():
         del hyperparams["env_wrapper"]

-    vec_env_kwargs: dict[str, Any] = {}
     # Avoid potential shared memory issue
-    vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv
+    if vec_env_cls is None:
+        vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv

     # Fix for gym 0.26, to keep old behavior
     env_kwargs = env_kwargs or {}
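A brief usage sketch of the extended signature (the env id is illustrative; leaving `vec_env_cls` as `None` keeps the old automatic `SubprocVecEnv`/`DummyVecEnv` choice):

from stable_baselines3.common.vec_env import DummyVecEnv
from rl_zoo3.utils import create_test_env

# Force DummyVecEnv even with several envs, e.g. to avoid subprocess issues
env = create_test_env(
    "CartPole-v1",
    n_envs=2,
    vec_env_cls=DummyVecEnv,
)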
@@ -349,21 +353,24 @@ def get_hf_trained_models(organization: str = "sb3", check_filename: bool = False
     for model in models:
         # Try to extract algorithm and environment id from model card
         try:
-            env_id = model.cardData["model-index"][0]["results"][0]["dataset"]["name"]
-            algo = model.cardData["model-index"][0]["name"].lower()
+            assert model.card_data is not None
+            env_id = model.card_data["model-index"][0]["results"][0]["dataset"]["name"]
+            algo = model.card_data["model-index"][0]["name"].lower()
             # RecurrentPPO alias is "ppo_lstm" in the rl zoo
             if algo == "recurrentppo":
                 algo = "ppo_lstm"
-        except (KeyError, IndexError):
-            print(f"Skipping {model.modelId}")
+        except (KeyError, IndexError, AssertionError):
+            print(f"Skipping {model.id}")
             continue  # skip model if the env id or algo name could not be found

         env_name = EnvironmentName(env_id)
         model_name = ModelName(algo, env_name)

         # check if there is a model file in the repo
-        if check_filename and not any(f.rfilename == model_name.filename for f in api.model_info(model.modelId).siblings):
-            continue  # skip model if the repo contains no properly named model file
+        if check_filename:
+            maybe_siblings = api.model_info(model.id).siblings
+            if maybe_siblings and not any(f.rfilename == model_name.filename for f in maybe_siblings):
+                continue  # skip model if the repo contains no properly named model file

         trained_models[model_name] = (algo, env_id)
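The old camelCase attributes were renamed in recent `huggingface_hub` releases, and `card_data` may now be `None`. A minimal sketch of the updated API (the query mirrors the `list_models` call used elsewhere in this function):

from huggingface_hub import HfApi

api = HfApi()
# cardData=True asks the Hub to return model card metadata with each result
for model in api.list_models(author="sb3", cardData=True, limit=5):
    # card_data can be None, hence the assert/except above
    print(model.id, model.card_data is not None)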

@@ -422,6 +429,8 @@ def get_saved_hyperparams(
             normalize_kwargs = eval(hyperparams["normalize"])
             if test_mode:
                 normalize_kwargs["norm_reward"] = norm_reward
+        elif isinstance(hyperparams["normalize"], dict):
+            normalize_kwargs = hyperparams["normalize"]
         else:
             normalize_kwargs = {"norm_obs": hyperparams["normalize"], "norm_reward": norm_reward}
         hyperparams["normalize_kwargs"] = normalize_kwargs

rl_zoo3/version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.5.0
+2.6.0a2

setup.py

Lines changed: 2 additions & 2 deletions
@@ -15,8 +15,8 @@
 See https://github.com/DLR-RM/rl-baselines3-zoo
 """
 install_requires = [
-    "sb3_contrib>=2.5.0,<3.0",
-    "gymnasium>=0.29.1,<1.1.0",
+    "sb3_contrib>=2.6.0a2,<3.0",
+    "gymnasium>=0.29.1,<1.2.0",
     "huggingface_sb3>=3.0,<4.0",
     "tqdm",
     "rich",
