Skip to content

Commit 9918741

Browse files
authored
[Test] Fix MuJoCo macro shape and gym Atari setup (#3811)
1 parent dd633fb commit 9918741

3 files changed

Lines changed: 25 additions & 1 deletion

File tree

.github/unittest/linux_libs/scripts_gym/run_all.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ printf "* Testing gym 0.20\n"
231231
pip install "pip<24.1" setuptools==65.3.0 wheel==0.38.4
232232
pip install 'gym[atari]==0.20'
233233
pip install 'ale-py==0.7.4'
234+
ale-import-roms Roms
234235
run_tests "gym==0.20" || true
235236
pip uninstall -y gym ale-py wheel || true
236237
pip install --upgrade pip setuptools wheel # restore latest versions

test/libs/test_mujoco.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def test_satellite_generic_macro_sequence(self):
117117
sequence = transform.action_sequence(
118118
td, MacroPrimitive.MOVE, target_qpos=target
119119
)
120-
assert sequence.shape == torch.Size([2, 4, env.action_spec.shape[-1]])
120+
assert sequence.shape == td.batch_size + torch.Size(
121+
[4, env.action_spec.shape[-1]]
122+
)
121123
env.close()
122124

123125
def test_humanoid_generic_macro_sequence(self):

torchrl/envs/libs/gym.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,26 @@
7575
"""
7676

7777

78+
def _patch_legacy_ale_py_gym_env(env_name: str) -> None:
79+
return
80+
81+
82+
@implement_for("gym", "0.20.0", "0.21.0")
83+
def _patch_legacy_ale_py_gym_env(env_name: str) -> None: # noqa: F811
84+
"""Expose the legacy ALE entry point used by gym 0.20 Atari specs."""
85+
del env_name
86+
try:
87+
ale_gym = importlib.import_module("ale_py.gym")
88+
except ImportError:
89+
return
90+
if hasattr(ale_gym, "ALGymEnv"):
91+
return
92+
try:
93+
ale_gym.ALGymEnv = gym_backend("envs.atari").AtariEnv
94+
except (AttributeError, ImportError):
95+
return
96+
97+
7898
def _gymnasium_reward_space(env):
7999
reward_space = getattr(env, "__dict__", {}).get("reward_space", None)
80100
if reward_space is not None:
@@ -1986,6 +2006,7 @@ def _build_env(
19862006
torchrl_logger.warning(
19872007
f"ale_py not found, this may cause issues with ALE environments: {err}"
19882008
)
2009+
_patch_legacy_ale_py_gym_env(env_name)
19892010
# we catch warnings as they may cause silent bugs
19902011
env = self.lib.make(env_name, **kwargs)
19912012
if len(w) and "frameskip" in str(w[-1].message):

0 commit comments

Comments
 (0)