Skip to content

Commit ca0d875

Browse files
Changed local to global imports
1 parent 123f8ab commit ca0d875

1 file changed

Lines changed: 1 addition & 14 deletions

File tree

test/test_libs.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
UnityMLAgentsWrapper,
122122
)
123123
from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper
124+
from torchrl.envs.libs.procgen import ProcgenEnv
124125

125126
from torchrl.envs.transforms import ActionMask, TransformedEnv
126127
from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType
@@ -5348,29 +5349,20 @@ def test_isaaclab_reset(self, env):
53485349
class TestProcgen:
53495350
@pytest.mark.parametrize("envname", ["coinrun", "starpilot"])
53505351
def test_procgen_envs_available(self, envname):
5351-
from torchrl.envs.libs.procgen import ProcgenEnv
5352-
53535352
# availability check
53545353
assert envname in ProcgenEnv.available_envs
53555354

53565355
def test_procgen_invalid_env_raises(self):
5357-
from torchrl.envs.libs.procgen import ProcgenEnv
5358-
import pytest
5359-
53605356
with pytest.raises(ValueError):
53615357
ProcgenEnv("this_env_does_not_exist")
53625358

53635359
def test_procgen_num_envs_batch_size(self):
5364-
from torchrl.envs.libs.procgen import ProcgenEnv
53655360
env = ProcgenEnv("coinrun", num_envs=3)
53665361
td = env.reset()
53675362
assert td["observation"].shape[0] == 3
53685363
env.close()
53695364

53705365
def test_procgen_seeding_is_deterministic(self):
5371-
from torchrl.envs.libs.procgen import ProcgenEnv
5372-
import torch
5373-
53745366
e1 = ProcgenEnv("coinrun", num_envs=2)
53755367
e2 = ProcgenEnv("coinrun", num_envs=2)
53765368
e1.set_seed(0)
@@ -5382,7 +5374,6 @@ def test_procgen_seeding_is_deterministic(self):
53825374
e2.close()
53835375

53845376
def test_procgen_step_keys_and_shapes(self):
5385-
from torchrl.envs.libs.procgen import ProcgenEnv
53865377
env = ProcgenEnv("coinrun", num_envs=2)
53875378
env.reset()
53885379
td = env.rand_step()
@@ -5393,17 +5384,13 @@ def test_procgen_step_keys_and_shapes(self):
53935384

53945385
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
53955386
def test_procgen_env_creation_and_reset(self):
5396-
from torchrl.envs.libs.procgen import ProcgenEnv
5397-
53985387
env = ProcgenEnv("coinrun", num_envs=4)
53995388
td = env.reset()
54005389
# ensure batch size corresponds to num_envs
54015390
assert td["observation"].shape[0] == 4
54025391

54035392
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
54045393
def test_procgen_env_step(self):
5405-
from torchrl.envs.libs.procgen import ProcgenEnv
5406-
54075394
env = ProcgenEnv("coinrun", num_envs=2)
54085395
env.reset()
54095396
out = env.rand_step()

0 commit comments

Comments
 (0)