|
9 | 9 | import itertools |
10 | 10 | import os |
11 | 11 | import queue as queue_lib |
| 12 | +import sys |
12 | 13 | import time |
13 | 14 | import traceback |
| 15 | +import types |
14 | 16 | from functools import partial |
15 | 17 |
|
16 | 18 | import pytest |
17 | 19 | import torch |
18 | 20 | import torch.distributed as dist |
19 | 21 | import torchrl.testing.env_helper |
20 | | -from tensordict import assert_allclose_td |
| 22 | +from tensordict import assert_allclose_td, TensorDict |
21 | 23 | from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq |
22 | 24 | from torch import multiprocessing as mp |
23 | 25 |
|
|
28 | 30 | from torchrl.data.replay_buffers.samplers import SliceSampler |
29 | 31 | from torchrl.data.replay_buffers.storages import LazyTensorStorage |
30 | 32 | from torchrl.envs import InitTracker, RewardSum, StepCounter, TransformedEnv, VecNormV2 |
| 33 | +from torchrl.envs.libs import gym as gym_lib, isaac_lab as isaac_lab_lib |
| 34 | +from torchrl.envs.libs.isaac_lab import IsaacLabWrapper |
31 | 35 | from torchrl.envs.utils import check_env_specs |
32 | 36 | from torchrl.modules import LSTMModule, MLP |
33 | 37 | from torchrl.testing import get_default_devices |
@@ -304,6 +308,105 @@ def _isaaclab_direct_native_autoreset(env_name: str, num_envs: int = 16): |
304 | 308 | proc.join() |
305 | 309 |
|
306 | 310 |
|
| 311 | +def _install_fake_isaaclab(monkeypatch): |
| 312 | + class ManagerBasedEnv: |
| 313 | + pass |
| 314 | + |
| 315 | + class DirectRLEnv: |
| 316 | + pass |
| 317 | + |
| 318 | + class DirectMARLEnv: |
| 319 | + pass |
| 320 | + |
| 321 | + fake_envs = types.ModuleType("isaaclab.envs") |
| 322 | + fake_envs.ManagerBasedEnv = ManagerBasedEnv |
| 323 | + fake_envs.DirectRLEnv = DirectRLEnv |
| 324 | + fake_envs.DirectMARLEnv = DirectMARLEnv |
| 325 | + fake_isaaclab = types.ModuleType("isaaclab") |
| 326 | + fake_isaaclab.envs = fake_envs |
| 327 | + monkeypatch.setitem(sys.modules, "isaaclab", fake_isaaclab) |
| 328 | + monkeypatch.setitem(sys.modules, "isaaclab.envs", fake_envs) |
| 329 | + return ManagerBasedEnv, DirectRLEnv, DirectMARLEnv |
| 330 | + |
| 331 | + |
| 332 | +def test_isaaclab_direct_env_detection_is_native_autoreset_opt_in(monkeypatch): |
| 333 | + ManagerBasedEnv, DirectRLEnv, DirectMARLEnv = _install_fake_isaaclab(monkeypatch) |
| 334 | + monkeypatch.setattr(gym_lib, "_has_isaaclab", True) |
| 335 | + monkeypatch.setattr(isaac_lab_lib, "_has_isaaclab", True) |
| 336 | + |
| 337 | + manager_env = ManagerBasedEnv() |
| 338 | + direct_env = DirectRLEnv() |
| 339 | + direct_marl_env = DirectMARLEnv() |
| 340 | + |
| 341 | + assert IsaacLabWrapper._supports_native_autoreset(manager_env) |
| 342 | + assert not IsaacLabWrapper._supports_native_autoreset(direct_env) |
| 343 | + assert not IsaacLabWrapper._supports_native_autoreset(direct_marl_env) |
| 344 | + assert IsaacLabWrapper._supports_native_autoreset(direct_env, native_autoreset=True) |
| 345 | + assert IsaacLabWrapper._supports_native_autoreset( |
| 346 | + direct_marl_env, native_autoreset=True |
| 347 | + ) |
| 348 | + |
| 349 | + fake_vector = types.SimpleNamespace(VectorEnv=type("VectorEnv", (), {})) |
| 350 | + monkeypatch.setattr( |
| 351 | + gym_lib, |
| 352 | + "gym_backend", |
| 353 | + lambda name=None: fake_vector if name == "vector" else fake_vector, |
| 354 | + ) |
| 355 | + wrapper = gym_lib.GymWrapper.__new__(gym_lib.GymWrapper) |
| 356 | + wrapper._torchrl_native_autoreset_requested = False |
| 357 | + wrapper._env = types.SimpleNamespace(unwrapped=manager_env) |
| 358 | + assert wrapper._is_batched |
| 359 | + wrapper._env = types.SimpleNamespace(unwrapped=direct_env) |
| 360 | + assert not wrapper._is_batched |
| 361 | + wrapper._torchrl_native_autoreset_requested = True |
| 362 | + assert wrapper._is_batched |
| 363 | + |
| 364 | + isaac_wrapper = IsaacLabWrapper.__new__(IsaacLabWrapper) |
| 365 | + isaac_wrapper._env = types.SimpleNamespace(unwrapped=direct_env) |
| 366 | + assert not isaac_wrapper._supports_set_state |
| 367 | + |
| 368 | + def reset_to(*args, **kwargs): |
| 369 | + return None |
| 370 | + |
| 371 | + manager_env.reset_to = reset_to |
| 372 | + isaac_wrapper._env = types.SimpleNamespace(unwrapped=manager_env) |
| 373 | + assert isaac_wrapper._supports_set_state |
| 374 | + |
| 375 | + |
| 376 | +def test_isaaclab_observation_key_normalization_is_cached_and_non_clobbering(): |
| 377 | + env = IsaacLabWrapper.__new__(IsaacLabWrapper) |
| 378 | + env._rename_policy_to_observation = False |
| 379 | + policy = torch.ones(2, 3) |
| 380 | + observations = {"policy": policy} |
| 381 | + assert env._normalize_observation_keys(observations) is observations |
| 382 | + |
| 383 | + env._rename_policy_to_observation = True |
| 384 | + normalized = env._normalize_observation_keys(observations) |
| 385 | + assert normalized is not observations |
| 386 | + assert "policy" not in normalized |
| 387 | + assert normalized["observation"] is policy |
| 388 | + |
| 389 | + existing_observation = torch.zeros(2, 3) |
| 390 | + observations = {"policy": policy, "observation": existing_observation} |
| 391 | + assert env._normalize_observation_keys(observations) is observations |
| 392 | + assert observations["observation"] is existing_observation |
| 393 | + |
| 394 | + |
| 395 | +def test_isaaclab_all_false_reset_to_state_is_no_op(): |
| 396 | + env = IsaacLabWrapper.__new__(IsaacLabWrapper) |
| 397 | + td = TensorDict( |
| 398 | + { |
| 399 | + "_reset": torch.zeros(3, 1, dtype=torch.bool), |
| 400 | + "policy": torch.ones(3, 2), |
| 401 | + }, |
| 402 | + batch_size=(3,), |
| 403 | + ) |
| 404 | + out = env._reset(td, set_state=True, scene_state=object()) |
| 405 | + assert "_reset" not in out.keys() |
| 406 | + assert out is not td |
| 407 | + assert (out["policy"] == td["policy"]).all() |
| 408 | + |
| 409 | + |
307 | 410 | @pytest.mark.skipif(not _has_isaac, reason="IsaacGym not found") |
308 | 411 | @pytest.mark.parametrize( |
309 | 412 | "task", |
|
0 commit comments