forked from Farama-Foundation/Gymnasium
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_time_limit.py
80 lines (61 loc) · 2.31 KB
/
test_time_limit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Test suite for TimeLimit wrapper."""
import pytest
import gymnasium as gym
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from gymnasium.wrappers import TimeLimit
def test_time_limit_reset_info():
    """Resetting through ``TimeLimit`` returns a valid observation and an info dict.

    The observation must lie inside the wrapped env's observation space and the
    second element of the reset tuple must be a ``dict``.
    """
    env = gym.make("CartPole-v1", disable_env_checker=True)
    env = TimeLimit(env, 100)
    ob_space = env.observation_space
    obs, info = env.reset()
    assert ob_space.contains(obs)
    assert isinstance(info, dict)
    # Release the environment, matching the cleanup done in test_max_episode_steps.
    env.close()
@pytest.mark.parametrize("double_wrap", [False, True])
def test_time_limit_wrapper(double_wrap):
    """``TimeLimit`` truncates the episode after exactly ``max_episode_steps`` steps.

    Runs with the wrapper applied once and, when ``double_wrap`` is true, twice
    with the same limit; both configurations must truncate at the same step.
    """
    # The pendulum env does not terminate by default, so the episode can only
    # end through the time-limit truncation.
    env = PendulumEnv()
    max_episode_length = 20
    env = TimeLimit(env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)

    env.reset()
    terminated, truncated = False, False
    n_steps = 0
    # Bounded loop: if the wrapper never truncates, the assertions below fail
    # instead of the test hanging forever.
    for n_steps in range(1, 2 * max_episode_length + 1):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            break

    assert n_steps == max_episode_length
    assert truncated
    # The episode ended because of the time limit, not a natural termination.
    assert not terminated
    env.close()
@pytest.mark.parametrize("double_wrap", [False, True])
def test_termination_on_last_step(double_wrap):
    """Termination and timeout truncation on the same step are both reported.

    The base env's ``step`` is patched to always return ``terminated=True``.
    With ``max_episode_steps=1`` the first step also reaches the time limit, so
    the wrapper (applied once or twice) must return ``terminated is True`` and
    ``truncated is True`` together.
    """
    base_env = PendulumEnv()

    # Close over ``base_env``, not ``env``: ``env`` is rebound to the TimeLimit
    # wrapper(s) below. A closure over ``env`` would late-bind to the wrapper
    # instead of the environment being patched.
    def patched_step(_action):
        return base_env.observation_space.sample(), 0.0, True, False, {}

    base_env.step = patched_step

    max_episode_length = 1
    env = TimeLimit(base_env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)

    env.reset()
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    assert terminated is True
    assert truncated is True
    env.close()
def test_max_episode_steps():
    """``max_episode_steps`` is exposed on the spec and validated by ``TimeLimit``.

    ``CartPole-v1`` created via ``gym.make`` reports a limit of 500 steps.
    Wrapping it in a new ``TimeLimit`` reports the new limit on the spec.
    Invalid limits (``-1`` and ``None``) are rejected with an ``AssertionError``
    carrying a descriptive message.
    """
    cartpole = gym.make("CartPole-v1", disable_env_checker=True)
    assert cartpole.spec.max_episode_steps == 500

    rewrapped = TimeLimit(cartpole, max_episode_steps=10)
    assert rewrapped.spec.max_episode_steps == 10

    # Each invalid limit is paired with the exact message the wrapper must raise.
    invalid_cases = (
        (-1, "Expect the `max_episode_steps` to be positive, actually: -1"),
        (None, "Expect the `max_episode_steps` to be positive, actually: None"),
    )
    for bad_limit, expected_message in invalid_cases:
        with pytest.raises(AssertionError, match=expected_message):
            TimeLimit(cartpole, max_episode_steps=bad_limit)

    cartpole.close()