Skip to content

Commit a4f8afa

Browse files
committed
Fix OpenEnv CI collection and formatting
1 parent d1ac215 commit a4f8afa

8 files changed

Lines changed: 35 additions & 14 deletions

File tree

test/libs/test_openenv.py

Lines changed: 24 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,11 @@
1212

1313
import pytest
1414
import torch
15-
from omegaconf import DictConfig
16-
from tensordict import TensorDict
17-
from tensordict.tensorclass import NonTensorData
1815

1916
import torchrl.envs.libs.openenv as openenv_mod
2017
import torchrl.envs.llm.libs.openenv as openenv_chat_mod
18+
from tensordict import TensorDict
19+
from tensordict.tensorclass import NonTensorData
2120
from torchrl.data import LazyStackStorage, ReplayBuffer
2221
from torchrl.data.llm import History
2322
from torchrl.envs.libs.openenv import OpenEnvEnv, OpenEnvWrapper
@@ -26,6 +25,8 @@
2625
from torchrl.modules.llm.policies.common import ChatHistory
2726
from torchrl.objectives.llm.grpo import MCAdvantage
2827

28+
_has_omegaconf = importlib.util.find_spec("omegaconf") is not None
29+
2930

3031
@dataclass
3132
class _StepResult:
@@ -46,6 +47,23 @@ def model_dump(self):
4647
return {"prompt": ["nested", {"value": 1}], "reward": 2.0, "done": True}
4748

4849

50+
class _Config(dict):
51+
def __init__(self, data):
52+
super().__init__((key, self._convert(value)) for key, value in data.items())
53+
54+
@classmethod
55+
def _convert(cls, value):
56+
if isinstance(value, dict):
57+
return cls(value)
58+
return value
59+
60+
def __getattr__(self, name):
61+
try:
62+
return self[name]
63+
except KeyError as err:
64+
raise AttributeError(name) from err
65+
66+
4967
class _SyncOpenEnv:
5068
def __init__(self):
5169
self.connected = False
@@ -226,6 +244,8 @@ def test_rand_step_check_env_specs_and_rollout(self):
226244

227245
class TestOpenEnvGRPO:
228246
def test_make_env_openenv_with_local_fixture(self, monkeypatch):
247+
if not _has_omegaconf:
248+
pytest.skip("omegaconf is required to import the GRPO recipe helpers")
229249
pytest.importorskip("transformers")
230250
pytest.importorskip("openenv")
231251
spec = importlib.util.spec_from_file_location(
@@ -246,7 +266,7 @@ def test_make_env_openenv_with_local_fixture(self, monkeypatch):
246266
"from_env",
247267
staticmethod(lambda name: _TextAction),
248268
)
249-
cfg = DictConfig(
269+
cfg = _Config(
250270
{
251271
"env": {
252272
"dataset": "openenv",

torchrl/data/llm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,13 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .history import add_chat_template, ContentBase, History
76
from .dataset import (
87
create_infinite_iterator,
98
get_dataloader,
109
TensorDictTokenizer,
1110
TokenizedDatasetLoader,
1211
)
12+
from .history import add_chat_template, ContentBase, History
1313
from .prompt import PromptData, PromptTensorDictTokenizer
1414
from .reward import PairwiseDataset, RewardData
1515
from .topk import TopKRewardSelector

torchrl/envs/llm/chat.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,15 +12,14 @@
1212
from tensordict import lazy_stack, TensorDictBase
1313
from tensordict.utils import _zip_strict
1414
from torch.utils.data import DataLoader
15-
from torchrl.data.tensor_specs import Composite, NonTensor
1615
from torchrl.data.llm.history import History
17-
from torchrl.envs.common import EnvBase
18-
from torchrl.envs.transforms import TransformedEnv
19-
from torchrl.envs.common import _EnvPostInit
16+
from torchrl.data.tensor_specs import Composite, NonTensor
17+
from torchrl.envs.common import _EnvPostInit, EnvBase
2018
from torchrl.envs.llm.transforms.dataloading import (
2119
DataLoadingPrimer,
2220
RayDataLoadingPrimer,
2321
)
22+
from torchrl.envs.transforms import TransformedEnv
2423
from torchrl.modules.llm.policies.common import ChatHistory, Text, Tokens
2524

2625
if TYPE_CHECKING:

torchrl/envs/llm/datasets/gsm8k.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,12 @@
1414
from tensordict.utils import _zip_strict
1515
from torch.utils.data import DataLoader
1616
from torchrl.data import TensorSpec
17-
from torchrl.envs.transforms import StepCounter, Transform
1817

1918
from torchrl.envs.llm.chat import DatasetChatEnv
2019

2120
from torchrl.envs.llm.envs import LLMEnv
2221
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
22+
from torchrl.envs.transforms import StepCounter, Transform
2323

2424
if TYPE_CHECKING:
2525
import transformers

torchrl/envs/llm/datasets/ifeval.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,9 @@
1111
import torch
1212
from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
1313
from torchrl.data import Composite, NonTensor, Unbounded
14-
from torchrl.envs.transforms import StepCounter
1514
from torchrl.envs.llm.chat import DatasetChatEnv
1615
from torchrl.envs.llm.reward.ifeval import IfEvalScorer
16+
from torchrl.envs.transforms import StepCounter
1717

1818
if TYPE_CHECKING:
1919
import transformers

torchrl/envs/llm/libs/openenv.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,9 @@ def _format_observation(self, observation: Any) -> Any:
276276
)
277277

278278
def _wrap_observation(self, observation: Any) -> NonTensorData:
279-
return NonTensorData(observation, batch_size=self.batch_size, device=self.device)
279+
return NonTensorData(
280+
observation, batch_size=self.batch_size, device=self.device
281+
)
280282

281283
def _make_history_message(self, role: str, content: Any) -> History:
282284
return History(

torchrl/envs/llm/reward/gsm8k.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111
from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
1212
from tensordict.utils import _zip_strict, is_non_tensor
1313
from torchrl.data import Composite, Unbounded
14-
from torchrl.envs.transforms import Transform
1514
from torchrl.envs.common import EnvBase
15+
from torchrl.envs.transforms import Transform
1616

1717

1818
class GSM8KRewardParser(Transform):

torchrl/envs/llm/transforms/reason.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@
1313
from torchrl._utils import logger as torchrl_logger
1414

1515
from torchrl.data.llm.history import History
16-
from torchrl.envs.transforms import Transform
1716
from torchrl.envs.common import EnvBase
17+
from torchrl.envs.transforms import Transform
1818

1919

2020
class AddThinkingPrompt(Transform):

0 commit comments

Comments
 (0)