Skip to content

Commit 52cf596

Browse files
authored
[BugFix] Fix IsaacLab reset regressions (#3869)
1 parent e797c19 commit 52cf596

5 files changed

Lines changed: 245 additions & 44 deletions

File tree

test/envs/test_env_base.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@
2525
from torchrl.data.tensor_specs import Binary, Composite, NonTensor, Unbounded
2626
from torchrl.envs import EnvBase, ParallelEnv, SerialEnv
2727
from torchrl.envs.libs.gym import gym_backend, GymEnv
28-
from torchrl.envs.transforms import StepCounter, TransformedEnv
28+
from torchrl.envs.transforms import StepCounter, Transform, TransformedEnv
2929
from torchrl.envs.utils import check_env_specs, make_composite_from_td, step_mdp
3030
from torchrl.modules import Actor
3131
from torchrl.testing import (
@@ -969,6 +969,51 @@ def _set_seed(self, *args, **kwargs):
969969
...
970970

971971

972+
class _KwargOnlySetStateEnv(EnvBase):
    """Tiny env with ``_supports_set_state = True`` that never sees state in its input.

    ``_input_td_has_state`` always answers ``False`` and ``_reset`` raises a
    ``RuntimeError`` when called with a truthy ``set_state`` keyword, so a
    reset that sets state without being asked to is surfaced as an error.
    """

    _supports_set_state = True

    def __init__(self, **kwargs):
        """Build an unbatched env with one-element observation/action/reward/done specs."""
        super().__init__(batch_size=(), **kwargs)
        unit_shape = (1,)
        self.observation_spec = Composite(observation=Unbounded(shape=unit_shape))
        self.action_spec = Unbounded(shape=unit_shape)
        self.reward_spec = Unbounded(shape=unit_shape)
        self.done_spec = Binary(n=1, shape=unit_shape, dtype=torch.bool)

    def _input_td_has_state(self, tensordict):
        """Report that the input never carries state, whatever it contains."""
        return False

    def _reset(self, tensordict, **kwargs):
        """Return a zero observation and a cleared done flag.

        Raises:
            RuntimeError: if a truthy ``set_state`` keyword is passed.
        """
        set_state_requested = kwargs.get("set_state")
        if set_state_requested:
            raise RuntimeError("unexpected implicit set_state")
        initial_fields = {
            "observation": torch.zeros(1),
            "done": torch.zeros(1, dtype=torch.bool),
        }
        return TensorDict(initial_fields, batch_size=())

    def _step(self, tensordict):
        """Add the action to the observation; reward is zero and done stays False."""
        next_observation = tensordict["observation"] + tensordict["action"]
        step_fields = {
            "observation": next_observation,
            "reward": torch.zeros(1),
            "done": torch.zeros(1, dtype=torch.bool),
        }
        return TensorDict(step_fields, batch_size=())

    def _set_seed(self, *args, **kwargs):
        """Seeding is a no-op for this env."""
        return None
1008+
1009+
1010+
class _OuterStateTransform(Transform):
    """Transform whose state spec gains an extra ``outer_state`` entry."""

    def transform_state_spec(self, state_spec):
        """Return a clone of ``state_spec`` with an unbounded ``outer_state`` of shape ``(1,)``."""
        # Work on a clone so the spec passed in is left untouched.
        extended_spec = state_spec.clone()
        extended_spec["outer_state"] = Unbounded(shape=(1,))
        return extended_spec
1015+
1016+
9721017
class TestResetSetState:
9731018
"""Tests for the explicit ``reset(td, set_state=True)`` deterministic-reset kwarg."""
9741019

@@ -1029,6 +1074,14 @@ def test_set_state_batched_parallel(self, maybe_fork_ParallelEnv):
10291074
finally:
10301075
env.close()
10311076

1077+
def test_transformed_env_delegates_implicit_state_detection(self):
    """Resetting with only ``outer_state`` in the input must not warn or set state.

    The base env raises if ``_reset`` receives a truthy ``set_state``, and any
    ``FutureWarning`` emitted during ``reset`` is promoted to an error.
    """
    base_env = _KwargOnlySetStateEnv()
    env = TransformedEnv(base_env, _OuterStateTransform())
    reset_input = TensorDict({"outer_state": torch.ones(1)}, batch_size=())
    with warnings.catch_warnings():
        # Turn FutureWarning into an exception so it fails the test.
        warnings.simplefilter("error", FutureWarning)
        result = env.reset(reset_input)
    observation_is_zero = result["observation"] == 0
    assert observation_is_zero.all()
1084+
10321085

10331086
if __name__ == "__main__":
10341087
args, unknown = argparse.ArgumentParser().parse_known_args()

test/libs/test_isaac.py

Lines changed: 104 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,17 @@
99
import itertools
1010
import os
1111
import queue as queue_lib
12+
import sys
1213
import time
1314
import traceback
15+
import types
1416
from functools import partial
1517

1618
import pytest
1719
import torch
1820
import torch.distributed as dist
1921
import torchrl.testing.env_helper
20-
from tensordict import assert_allclose_td
22+
from tensordict import assert_allclose_td, TensorDict
2123
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
2224
from torch import multiprocessing as mp
2325

@@ -28,6 +30,8 @@
2830
from torchrl.data.replay_buffers.samplers import SliceSampler
2931
from torchrl.data.replay_buffers.storages import LazyTensorStorage
3032
from torchrl.envs import InitTracker, RewardSum, StepCounter, TransformedEnv, VecNormV2
33+
from torchrl.envs.libs import gym as gym_lib, isaac_lab as isaac_lab_lib
34+
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
3135
from torchrl.envs.utils import check_env_specs
3236
from torchrl.modules import LSTMModule, MLP
3337
from torchrl.testing import get_default_devices
@@ -304,6 +308,105 @@ def _isaaclab_direct_native_autoreset(env_name: str, num_envs: int = 16):
304308
proc.join()
305309

306310

311+
def _install_fake_isaaclab(monkeypatch):
    """Register stand-in ``isaaclab`` and ``isaaclab.envs`` modules in ``sys.modules``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; its ``setitem`` is used
            for the ``sys.modules`` entries.

    Returns:
        The tuple ``(ManagerBasedEnv, DirectRLEnv, DirectMARLEnv)`` of empty
        placeholder classes exposed on the fake ``isaaclab.envs`` module.
    """

    class ManagerBasedEnv:
        pass

    class DirectRLEnv:
        pass

    class DirectMARLEnv:
        pass

    placeholder_classes = (ManagerBasedEnv, DirectRLEnv, DirectMARLEnv)

    # Expose each placeholder on the fake submodule under its own class name.
    envs_module = types.ModuleType("isaaclab.envs")
    for placeholder in placeholder_classes:
        setattr(envs_module, placeholder.__name__, placeholder)

    package_module = types.ModuleType("isaaclab")
    package_module.envs = envs_module

    # Register the package first, then the submodule.
    registrations = (
        ("isaaclab", package_module),
        ("isaaclab.envs", envs_module),
    )
    for dotted_name, module in registrations:
        monkeypatch.setitem(sys.modules, dotted_name, module)

    return placeholder_classes
330+
331+
332+
def test_isaaclab_direct_env_detection_is_native_autoreset_opt_in(monkeypatch):
    """Direct IsaacLab envs are treated as native-autoreset/batched only on request.

    Uses fake ``isaaclab`` modules and wrappers created with ``__new__`` (so
    ``__init__`` never runs) to check three things: native-autoreset
    detection, ``GymWrapper._is_batched``, and
    ``IsaacLabWrapper._supports_set_state``.
    """
    ManagerBasedEnv, DirectRLEnv, DirectMARLEnv = _install_fake_isaaclab(monkeypatch)
    monkeypatch.setattr(gym_lib, "_has_isaaclab", True)
    monkeypatch.setattr(isaac_lab_lib, "_has_isaaclab", True)

    manager_env = ManagerBasedEnv()
    direct_env = DirectRLEnv()
    direct_marl_env = DirectMARLEnv()

    # Manager-based envs qualify by default; direct envs need the explicit flag.
    assert IsaacLabWrapper._supports_native_autoreset(manager_env)
    assert not IsaacLabWrapper._supports_native_autoreset(direct_env)
    assert not IsaacLabWrapper._supports_native_autoreset(direct_marl_env)
    assert IsaacLabWrapper._supports_native_autoreset(direct_env, native_autoreset=True)
    assert IsaacLabWrapper._supports_native_autoreset(
        direct_marl_env, native_autoreset=True
    )

    fake_vector = types.SimpleNamespace(VectorEnv=type("VectorEnv", (), {}))
    # The stub backend returns the same namespace whatever name is requested.
    monkeypatch.setattr(gym_lib, "gym_backend", lambda name=None: fake_vector)

    wrapper = gym_lib.GymWrapper.__new__(gym_lib.GymWrapper)
    wrapper._torchrl_native_autoreset_requested = False
    wrapper._env = types.SimpleNamespace(unwrapped=manager_env)
    assert wrapper._is_batched
    wrapper._env = types.SimpleNamespace(unwrapped=direct_env)
    assert not wrapper._is_batched
    # Same direct env, but now with native autoreset requested.
    wrapper._torchrl_native_autoreset_requested = True
    assert wrapper._is_batched

    isaac_wrapper = IsaacLabWrapper.__new__(IsaacLabWrapper)
    isaac_wrapper._env = types.SimpleNamespace(unwrapped=direct_env)
    assert not isaac_wrapper._supports_set_state

    def reset_to(*args, **kwargs):
        return None

    # With a ``reset_to`` attribute, the manager env reports set-state support.
    manager_env.reset_to = reset_to
    isaac_wrapper._env = types.SimpleNamespace(unwrapped=manager_env)
    assert isaac_wrapper._supports_set_state
374+
375+
376+
def test_isaaclab_observation_key_normalization_is_cached_and_non_clobbering():
    """Check ``_normalize_observation_keys`` with renaming off, on, and with a key clash."""
    wrapper = IsaacLabWrapper.__new__(IsaacLabWrapper)

    # Renaming disabled: the very same mapping object comes back.
    wrapper._rename_policy_to_observation = False
    policy = torch.ones(2, 3)
    policy_only = {"policy": policy}
    assert wrapper._normalize_observation_keys(policy_only) is policy_only

    # Renaming enabled: a new mapping holds the tensor under "observation".
    wrapper._rename_policy_to_observation = True
    renamed = wrapper._normalize_observation_keys(policy_only)
    assert renamed is not policy_only
    assert "policy" not in renamed
    assert renamed["observation"] is policy

    # An existing "observation" entry is kept and the input is returned as-is.
    existing_observation = torch.zeros(2, 3)
    both_keys = {"policy": policy, "observation": existing_observation}
    assert wrapper._normalize_observation_keys(both_keys) is both_keys
    assert both_keys["observation"] is existing_observation
393+
394+
395+
def test_isaaclab_all_false_reset_to_state_is_no_op():
    """With an all-False ``_reset`` mask, ``_reset(set_state=True)`` keeps the data.

    The wrapper is created with ``__new__`` (``__init__`` never runs) and
    ``scene_state`` is an opaque placeholder object.
    """
    wrapper = IsaacLabWrapper.__new__(IsaacLabWrapper)
    fields = {
        "_reset": torch.zeros(3, 1, dtype=torch.bool),
        "policy": torch.ones(3, 2),
    }
    td = TensorDict(fields, batch_size=(3,))
    result = wrapper._reset(td, set_state=True, scene_state=object())
    # The mask is dropped, a new container is returned, and the values match.
    assert "_reset" not in result.keys()
    assert result is not td
    policy_matches = result["policy"] == td["policy"]
    assert policy_matches.all()
408+
409+
307410
@pytest.mark.skipif(not _has_isaac, reason="IsaacGym not found")
308411
@pytest.mark.parametrize(
309412
"task",

torchrl/envs/libs/gym.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -873,6 +873,9 @@ def __call__(cls, *args, **kwargs):
873873
missing_obs_value = kwargs.pop("missing_obs_value", None)
874874
native_autoreset = kwargs.pop("native_autoreset", False)
875875
num_workers = kwargs.pop("num_workers", 1)
876+
native_autoreset = kwargs.setdefault(
877+
"_torchrl_native_autoreset_requested", native_autoreset
878+
)
876879

877880
if cls.__name__ == "GymEnv" and num_workers > 1:
878881
from torchrl.envs import EnvCreator, ParallelEnv
@@ -903,7 +906,9 @@ def __call__(cls, *args, **kwargs):
903906
kwargs = {}
904907
if missing_obs_value is not None:
905908
kwargs["missing_obs_value"] = missing_obs_value
906-
if IsaacLabWrapper._supports_native_autoreset(instance._env.unwrapped):
909+
if IsaacLabWrapper._supports_native_autoreset(
910+
instance._env.unwrapped, native_autoreset=native_autoreset
911+
):
907912
env = TransformedEnv(
908913
instance,
909914
VecGymEnvTransform(**kwargs, native_autoreset=native_autoreset),
@@ -1131,6 +1136,9 @@ def get_library_name(env) -> str:
11311136
)
11321137

11331138
def __init__(self, env=None, categorical_action_encoding=False, **kwargs):
1139+
self._torchrl_native_autoreset_requested = kwargs.pop(
1140+
"_torchrl_native_autoreset_requested", False
1141+
)
11341142
self._seed_calls_reset = None
11351143
self._categorical_action_encoding = categorical_action_encoding
11361144
if env is not None:
@@ -1204,7 +1212,10 @@ def _is_batched(self):
12041212
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
12051213

12061214
tuple_of_classes = (
1207-
tuple_of_classes + IsaacLabWrapper._supported_isaac_env_classes()
1215+
tuple_of_classes
1216+
+ IsaacLabWrapper._supported_isaac_env_classes(
1217+
include_direct=self._torchrl_native_autoreset_requested
1218+
)
12081219
)
12091220
return isinstance(
12101221
self._env.unwrapped, tuple_of_classes + (gym_backend("vector").VectorEnv,)

0 commit comments

Comments
 (0)