Skip to content

Commit 8c73ca3

Browse files
committed
Merge remote-tracking branch 'ParamThakkar123/add/procgen' into add/procgen
2 parents ca74740 + de55a7a commit 8c73ca3

2 files changed

Lines changed: 1 addition & 56 deletions

File tree

.github/workflows/test-linux-libs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ jobs:
541541
unittests-procgen:
542542
strategy:
543543
matrix:
544-
python_version: ["3.9"]
544+
python_version: ["3.10"]
545545
cuda_arch_version: ["12.8"]
546546
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') }}
547547
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

test/test_libs.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5416,61 +5416,6 @@ def test_isaaclab_reset(self, env):
54165416
assert not r["next", "policy"][r["next", "done"].squeeze(-1)].isfinite().any()
54175417

54185418

5419-
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
5420-
class TestProcgen:
5421-
@pytest.mark.parametrize("envname", ["coinrun", "starpilot"])
5422-
def test_procgen_envs_available(self, envname):
5423-
# availability check
5424-
assert envname in ProcgenEnv.available_envs
5425-
5426-
def test_procgen_invalid_env_raises(self):
5427-
with pytest.raises(ValueError):
5428-
ProcgenEnv("this_env_does_not_exist")
5429-
5430-
def test_procgen_num_envs_batch_size(self):
5431-
env = ProcgenEnv("coinrun", num_envs=3)
5432-
td = env.reset()
5433-
assert td["observation"].shape[0] == 3
5434-
env.close()
5435-
5436-
def test_procgen_seeding_is_deterministic(self):
5437-
e1 = ProcgenEnv("coinrun", num_envs=2)
5438-
e2 = ProcgenEnv("coinrun", num_envs=2)
5439-
e1.set_seed(0)
5440-
e2.set_seed(0)
5441-
t1 = e1.reset()
5442-
t2 = e2.reset()
5443-
assert torch.equal(t1["observation"], t2["observation"])
5444-
e1.close()
5445-
e2.close()
5446-
5447-
def test_procgen_step_keys_and_shapes(self):
5448-
env = ProcgenEnv("coinrun", num_envs=2)
5449-
env.reset()
5450-
td = env.rand_step()
5451-
for k in ("observation", "reward", "done"):
5452-
assert k in td
5453-
assert td["observation"].shape[0] == 2
5454-
env.close()
5455-
5456-
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
5457-
def test_procgen_env_creation_and_reset(self):
5458-
env = ProcgenEnv("coinrun", num_envs=4)
5459-
td = env.reset()
5460-
# ensure batch size corresponds to num_envs
5461-
assert td["observation"].shape[0] == 4
5462-
5463-
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
5464-
def test_procgen_env_step(self):
5465-
env = ProcgenEnv("coinrun", num_envs=2)
5466-
env.reset()
5467-
out = env.rand_step()
5468-
# basic checks on returned tensordict
5469-
assert "observation" in out
5470-
assert "reward" in out
5471-
assert "done" in out
5472-
5473-
54745419
@pytest.mark.skipif(not _has_procgen, reason="Procgen not found")
54755420
class TestProcgen:
54765421
@pytest.mark.parametrize("envname", ["coinrun", "starpilot"])

0 commit comments

Comments
 (0)