@@ -5416,61 +5416,6 @@ def test_isaaclab_reset(self, env):
54165416 assert not r ["next" , "policy" ][r ["next" , "done" ].squeeze (- 1 )].isfinite ().any ()
54175417
54185418
5419- @pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
5420- class TestProcgen :
5421- @pytest .mark .parametrize ("envname" , ["coinrun" , "starpilot" ])
5422- def test_procgen_envs_available (self , envname ):
5423- # availability check
5424- assert envname in ProcgenEnv .available_envs
5425-
5426- def test_procgen_invalid_env_raises (self ):
5427- with pytest .raises (ValueError ):
5428- ProcgenEnv ("this_env_does_not_exist" )
5429-
5430- def test_procgen_num_envs_batch_size (self ):
5431- env = ProcgenEnv ("coinrun" , num_envs = 3 )
5432- td = env .reset ()
5433- assert td ["observation" ].shape [0 ] == 3
5434- env .close ()
5435-
5436- def test_procgen_seeding_is_deterministic (self ):
5437- e1 = ProcgenEnv ("coinrun" , num_envs = 2 )
5438- e2 = ProcgenEnv ("coinrun" , num_envs = 2 )
5439- e1 .set_seed (0 )
5440- e2 .set_seed (0 )
5441- t1 = e1 .reset ()
5442- t2 = e2 .reset ()
5443- assert torch .equal (t1 ["observation" ], t2 ["observation" ])
5444- e1 .close ()
5445- e2 .close ()
5446-
5447- def test_procgen_step_keys_and_shapes (self ):
5448- env = ProcgenEnv ("coinrun" , num_envs = 2 )
5449- env .reset ()
5450- td = env .rand_step ()
5451- for k in ("observation" , "reward" , "done" ):
5452- assert k in td
5453- assert td ["observation" ].shape [0 ] == 2
5454- env .close ()
5455-
5456- @pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
5457- def test_procgen_env_creation_and_reset (self ):
5458- env = ProcgenEnv ("coinrun" , num_envs = 4 )
5459- td = env .reset ()
5460- # ensure batch size corresponds to num_envs
5461- assert td ["observation" ].shape [0 ] == 4
5462-
5463- @pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
5464- def test_procgen_env_step (self ):
5465- env = ProcgenEnv ("coinrun" , num_envs = 2 )
5466- env .reset ()
5467- out = env .rand_step ()
5468- # basic checks on returned tensordict
5469- assert "observation" in out
5470- assert "reward" in out
5471- assert "done" in out
5472-
5473-
54745419@pytest .mark .skipif (not _has_procgen , reason = "Procgen not found" )
54755420class TestProcgen :
54765421 @pytest .mark .parametrize ("envname" , ["coinrun" , "starpilot" ])
0 commit comments