121121 UnityMLAgentsWrapper ,
122122)
123123from torchrl .envs .libs .vmas import _has_vmas , VmasEnv , VmasWrapper
124+ from torchrl .envs .libs .procgen import ProcgenEnv
124125
125126from torchrl .envs .transforms import ActionMask , TransformedEnv
126127from torchrl .envs .utils import check_env_specs , ExplorationType , MarlGroupMapType
@@ -5348,29 +5349,20 @@ def test_isaaclab_reset(self, env):
53485349class TestProcgen :
53495350 @pytest .mark .parametrize ("envname" , ["coinrun" , "starpilot" ])
53505351 def test_procgen_envs_available (self , envname ):
5351- from torchrl .envs .libs .procgen import ProcgenEnv
5352-
53535352 # availability check
53545353 assert envname in ProcgenEnv .available_envs
53555354
53565355 def test_procgen_invalid_env_raises (self ):
5357- from torchrl .envs .libs .procgen import ProcgenEnv
5358- import pytest
5359-
53605356 with pytest .raises (ValueError ):
53615357 ProcgenEnv ("this_env_does_not_exist" )
53625358
53635359 def test_procgen_num_envs_batch_size (self ):
5364- from torchrl .envs .libs .procgen import ProcgenEnv
53655360 env = ProcgenEnv ("coinrun" , num_envs = 3 )
53665361 td = env .reset ()
53675362 assert td ["observation" ].shape [0 ] == 3
53685363 env .close ()
53695364
53705365 def test_procgen_seeding_is_deterministic (self ):
5371- from torchrl .envs .libs .procgen import ProcgenEnv
5372- import torch
5373-
53745366 e1 = ProcgenEnv ("coinrun" , num_envs = 2 )
53755367 e2 = ProcgenEnv ("coinrun" , num_envs = 2 )
53765368 e1 .set_seed (0 )
@@ -5382,7 +5374,6 @@ def test_procgen_seeding_is_deterministic(self):
53825374 e2 .close ()
53835375
53845376 def test_procgen_step_keys_and_shapes (self ):
5385- from torchrl .envs .libs .procgen import ProcgenEnv
53865377 env = ProcgenEnv ("coinrun" , num_envs = 2 )
53875378 env .reset ()
53885379 td = env .rand_step ()
@@ -5393,17 +5384,13 @@ def test_procgen_step_keys_and_shapes(self):
53935384
53945385 @pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
53955386 def test_procgen_env_creation_and_reset (self ):
5396- from torchrl .envs .libs .procgen import ProcgenEnv
5397-
53985387 env = ProcgenEnv ("coinrun" , num_envs = 4 )
53995388 td = env .reset ()
54005389 # ensure batch size corresponds to num_envs
54015390 assert td ["observation" ].shape [0 ] == 4
54025391
54035392 @pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
54045393 def test_procgen_env_step (self ):
5405- from torchrl .envs .libs .procgen import ProcgenEnv
5406-
54075394 env = ProcgenEnv ("coinrun" , num_envs = 2 )
54085395 env .reset ()
54095396 out = env .rand_step ()
0 commit comments