|
5 | 5 | from __future__ import annotations |
6 | 6 |
|
7 | 7 | import argparse |
| 8 | +import importlib.util |
8 | 9 | import inspect |
9 | 10 |
|
10 | 11 | import pytest |
|
16 | 17 | from torchrl.data.datasets.utils import load_dataset, register_dataset |
17 | 18 | from torchrl.data.replay_buffers.offline_to_online import prefill_replay_buffer |
18 | 19 | from torchrl.envs.libs.gym import _has_gym |
| 20 | +from torchrl.testing.gym_helpers import PENDULUM_VERSIONED |
19 | 21 |
|
20 | 22 | # Running a SAC loss requires a tensordict new enough to support |
21 | 23 | # ``to_module(preserve_module_state=...)``; the offline-to-online wiring itself |
22 | 24 | # does not. |
23 | 25 | _LOSS_RUNNABLE = ( |
24 | | - "preserve_module_state" |
25 | | - in __import__("inspect").signature(TensorDict.to_module).parameters |
| 26 | + "preserve_module_state" in inspect.signature(TensorDict.to_module).parameters |
| 27 | +) |
| 28 | +_CONFIGS_AVAILABLE = ( |
| 29 | + importlib.util.find_spec("hydra") is not None |
| 30 | + and importlib.util.find_spec("omegaconf") is not None |
26 | 31 | ) |
27 | 32 |
|
28 | 33 |
|
@@ -583,6 +588,10 @@ def test_constructor_exposes_sac_key_and_logging_kwargs(self): |
583 | 588 | ): |
584 | 589 | assert name in parameters |
585 | 590 |
|
| 591 | + @pytest.mark.skipif( |
| 592 | + not _CONFIGS_AVAILABLE, |
| 593 | + reason="Config system requires hydra-core and omegaconf", |
| 594 | + ) |
586 | 595 | def test_config_registered(self): |
587 | 596 | from torchrl.trainers.algorithms.configs import OfflineToOnlineTrainerConfig |
588 | 597 |
|
@@ -637,7 +646,7 @@ def test_train_grows_online_and_anneals(self, tmp_path): |
637 | 646 | from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer |
638 | 647 |
|
639 | 648 | torch.manual_seed(0) |
640 | | - env = GymEnv("Pendulum-v1") |
| 649 | + env = GymEnv(PENDULUM_VERSIONED()) |
641 | 650 | obs_dim = env.observation_spec["observation"].shape[-1] |
642 | 651 | action_dim = env.action_spec.shape[-1] |
643 | 652 |
|
|
0 commit comments