Skip to content

Commit 377437b

Browse files
committed
[BugFix] Fix offline-to-online CI failures
1 parent c50a082 commit 377437b

7 files changed

Lines changed: 47 additions & 26 deletions

File tree

.github/unittest/linux_libs/scripts_ataridqn/install.sh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive
2929
printf "Installing PyTorch with cu128"
3030
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3131
if [ "${CU_VERSION:-}" == cpu ] ; then
32-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
32+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
3334
else
34-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
36+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
3537
fi
3638
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3739
if [ "${CU_VERSION:-}" == cpu ] ; then
38-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
41+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps
3942
else
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
43+
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
44+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps
4145
fi
4246
else
4347
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_gen-dgrl/install.sh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive
2929
printf "Installing PyTorch with cu128"
3030
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3131
if [ "${CU_VERSION:-}" == cpu ] ; then
32-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
32+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
3334
else
34-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
36+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
3537
fi
3638
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3739
if [ "${CU_VERSION:-}" == cpu ] ; then
38-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
40+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
41+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu --no-deps
3942
else
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128
43+
pip3 install torch --index-url https://download.pytorch.org/whl/cu128
44+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 --no-deps
4145
fi
4246
else
4347
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_openx/install.sh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive
2929
printf "Installing PyTorch with cu128"
3030
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3131
if [ "${CU_VERSION:-}" == cpu ] ; then
32-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
32+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
3334
else
34-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
36+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
3537
fi
3638
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3739
if [ "${CU_VERSION:-}" == cpu ] ; then
38-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
40+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
41+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu -U --no-deps
3942
else
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U
43+
pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U
44+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 -U --no-deps
4145
fi
4246
else
4347
printf "Failed to install pytorch"

.github/unittest/linux_libs/scripts_vd4rl/install.sh

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ git submodule sync && git submodule update --init --recursive
2929
printf "Installing PyTorch with cu128"
3030
if [[ "$TORCH_VERSION" == "nightly" ]]; then
3131
if [ "${CU_VERSION:-}" == cpu ] ; then
32-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U
32+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U
33+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cpu -U --no-deps
3334
else
34-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U
35+
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U
36+
pip3 install --pre torchvision --index-url https://download.pytorch.org/whl/nightly/cu128 -U --no-deps
3537
fi
3638
elif [[ "$TORCH_VERSION" == "stable" ]]; then
3739
if [ "${CU_VERSION:-}" == cpu ] ; then
38-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu -U
40+
pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U
41+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cpu -U --no-deps
3942
else
40-
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu128 -U
43+
pip3 install torch --index-url https://download.pytorch.org/whl/cu128 -U
44+
pip3 install torchvision --index-url https://download.pytorch.org/whl/cu128 -U --no-deps
4145
fi
4246
else
4347
printf "Failed to install pytorch"

test/test_offline_to_online.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import argparse
8+
import importlib.util
89
import inspect
910

1011
import pytest
@@ -16,13 +17,17 @@
1617
from torchrl.data.datasets.utils import load_dataset, register_dataset
1718
from torchrl.data.replay_buffers.offline_to_online import prefill_replay_buffer
1819
from torchrl.envs.libs.gym import _has_gym
20+
from torchrl.testing.gym_helpers import PENDULUM_VERSIONED
1921

2022
# Running a SAC loss requires a tensordict new enough to support
2123
# ``to_module(preserve_module_state=...)``; the offline-to-online wiring itself
2224
# does not.
2325
_LOSS_RUNNABLE = (
24-
"preserve_module_state"
25-
in __import__("inspect").signature(TensorDict.to_module).parameters
26+
"preserve_module_state" in inspect.signature(TensorDict.to_module).parameters
27+
)
28+
_CONFIGS_AVAILABLE = (
29+
importlib.util.find_spec("hydra") is not None
30+
and importlib.util.find_spec("omegaconf") is not None
2631
)
2732

2833

@@ -583,6 +588,10 @@ def test_constructor_exposes_sac_key_and_logging_kwargs(self):
583588
):
584589
assert name in parameters
585590

591+
@pytest.mark.skipif(
592+
not _CONFIGS_AVAILABLE,
593+
reason="Config system requires hydra-core and omegaconf",
594+
)
586595
def test_config_registered(self):
587596
from torchrl.trainers.algorithms.configs import OfflineToOnlineTrainerConfig
588597

@@ -637,7 +646,7 @@ def test_train_grows_online_and_anneals(self, tmp_path):
637646
from torchrl.trainers.algorithms.offline_to_online import OfflineToOnlineTrainer
638647

639648
torch.manual_seed(0)
640-
env = GymEnv("Pendulum-v1")
649+
env = GymEnv(PENDULUM_VERSIONED())
641650
obs_dim = env.observation_spec["observation"].shape[-1]
642651
action_dim = env.action_spec.shape[-1]
643652

torchrl/trainers/algorithms/configs/trainers.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,7 @@ def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
221221

222222
@dataclass
223223
class OfflineToOnlineTrainerConfig(SACTrainerConfig):
224-
"""Hydra configuration for
225-
:class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`.
224+
"""Hydra configuration for :class:`~torchrl.trainers.algorithms.OfflineToOnlineTrainer`.
226225
227226
Every kwarg accepted by ``OfflineToOnlineTrainer.__init__`` is exposed as a
228227
field here, with SAC network-construction helper fields inherited from
@@ -267,9 +266,7 @@ def _make_offline_to_online_trainer(*args, **kwargs) -> OfflineToOnlineTrainer:
267266
target_net_updater = kwargs.pop("target_net_updater")
268267
async_collection = kwargs.pop("async_collection", False)
269268
if async_collection:
270-
raise ValueError(
271-
"OfflineToOnlineTrainer does not support async_collection."
272-
)
269+
raise ValueError("OfflineToOnlineTrainer does not support async_collection.")
273270
log_timings = kwargs.pop("log_timings", False)
274271
auto_log_optim_steps = kwargs.pop("auto_log_optim_steps", True)
275272
batch_size = kwargs.pop("batch_size", None)

torchrl/trainers/algorithms/offline_to_online.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ def register(self, trainer, name: str = "offline_to_online_anneal") -> None:
148148
class OfflineToOnlineTrainer(SACTrainer):
149149
"""A SAC trainer for the offline-pretrain -> online-finetune transition.
150150
151-
See also
152-
:class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig`
151+
See also :class:`~torchrl.trainers.algorithms.configs.OfflineToOnlineTrainerConfig`
153152
for the Hydra configuration counterpart.
154153
155154
Builds on :class:`~torchrl.trainers.algorithms.SACTrainer`, swapping the

0 commit comments

Comments
 (0)