Skip to content

Commit 0e36f64

Browse files
authored
Merge pull request #64 from automl/gymnax-env-params-argument
add option to override gymnax env params
2 parents d768e29 + 23da3a2 commit 0e36f64

5 files changed

Lines changed: 15 additions & 3 deletions

File tree

arlbench/autorl/autorl_env.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,8 @@
3737
"env_name": "CartPole-v1",
3838
"env_kwargs": {},
3939
"eval_env_kwargs": {},
40+
"env_params": {},
41+
"env_eval_params": {},
4042
"n_envs": 10,
4143
"algorithm": "dqn",
4244
"cnn_policy": False,
@@ -107,6 +109,7 @@ def __init__(self, config: dict | None = None) -> None:
107109
env_kwargs=self._config["env_kwargs"],
108110
cnn_policy=self._config["cnn_policy"],
109111
seed=self._seed,
112+
env_params=self._config.get("env_params", None)
110113
)
111114

112115
self._eval_env = make_env(
@@ -116,6 +119,7 @@ def __init__(self, config: dict | None = None) -> None:
116119
env_kwargs=self._config["eval_env_kwargs"],
117120
cnn_policy=self._config["cnn_policy"],
118121
seed=self._seed + 1,
122+
env_params=self._config.get("env_eval_params", None)
119123
)
120124

121125
# Checkpointing

arlbench/core/environments/gymnax_env.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import gymnax
88
import jax
9+
from dataclasses import replace
910

1011
from .autorl_env import Environment
1112

@@ -17,7 +18,7 @@ class GymnaxEnv(Environment):
1718
"""A gymnax-based RL environment."""
1819

1920
def __init__(
20-
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None
21+
self, env_name: str, n_envs: int, env_kwargs: dict[str, Any] | None = None, env_params: dict[str, Any] | None = None
2122
):
2223
"""Creates a gymnax environment for JAX-based RL training.
2324
@@ -29,7 +30,8 @@ def __init__(
2930
"""
3031
if env_kwargs is None:
3132
env_kwargs = {}
32-
env, env_params = gymnax.make(env_name, **env_kwargs)
33+
env, og_env_params = gymnax.make(env_name, **env_kwargs)
34+
env_params = replace(og_env_params, **(env_params or {}))
3335
super().__init__(env_name, env, n_envs)
3436

3537
self.env_params = env_params

arlbench/core/environments/make_env.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ def make_env(
1919
n_envs: int = 1,
2020
seed: int = 0,
2121
env_kwargs: dict[str, Any] | None = None,
22+
env_params: dict[str, Any] | None = None,
2223
) -> Environment | Wrapper:
2324
"""ARLBench equivalent to make_env in gymnasium/gymnax etc.
2425
Creates a JAX-compatible RL environment.
@@ -50,7 +51,7 @@ def make_env(
5051
elif env_framework == "gymnax":
5152
from .gymnax_env import GymnaxEnv
5253

53-
env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs)
54+
env = GymnaxEnv(env_name, n_envs, env_kwargs=env_kwargs, env_params=env_params)
5455
elif env_framework == "envpool":
5556
from .envpool_env import EnvpoolEnv
5657

examples/configs/base.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ autorl:
2020
env_framework: ${environment.framework}
2121
env_name: ${environment.name}
2222
env_kwargs: ${environment.kwargs}
23+
env_params: ${environment.env_params}
2324
eval_env_kwargs: ${environment.eval_kwargs}
2425
n_envs: ${environment.n_envs}
2526
algorithm: ${algorithm}

examples/configs/environment/cc_cartpole.yaml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,3 +7,7 @@ cnn_policy: False
77
deterministic_eval: True
88
jax_enable_x64: False
99
n_envs: 8
10+
11+
env_params:
12+
masspole: 0.9
13+
length: 0.7

0 commit comments

Comments
 (0)