From 7fc466fa6b78960b272743a0096c249ae0ea2b38 Mon Sep 17 00:00:00 2001
From: troiganto
Date: Mon, 26 Aug 2024 17:52:37 +0200
Subject: [PATCH] Fix outdated docs for TimeLimit max_episode_steps and add validation.

Closes #1147.
---
Review notes (text between the `---` separator and the diff is not part of
the commit message):
  - Line structure was restored from a whitespace-collapsed copy of this
    patch. The indentation of context lines is reconstructed from Python
    syntax; check it against the target files before applying.
  - The added assertion rejects every non-int value (the new test passes
    None), although its message only mentions positivity.
  - `assert` statements are removed when Python runs with -O, so the check
    is not enforced in optimized mode.

 gymnasium/wrappers/common.py      |  5 ++++-
 tests/wrappers/test_time_limit.py | 21 +++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/gymnasium/wrappers/common.py b/gymnasium/wrappers/common.py
index e66f6f7afd..9b3d225bdb 100644
--- a/gymnasium/wrappers/common.py
+++ b/gymnasium/wrappers/common.py
@@ -96,8 +96,11 @@ def __init__(
 
         Args:
             env: The environment to apply the wrapper
-            max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
+            max_episode_steps: the environment step after which the episode is truncated (``elapsed >= max_episode_steps``)
         """
+        assert (
+            isinstance(max_episode_steps, int) and max_episode_steps > 0
+        ), f"Expect the `max_episode_steps` to be positive, actually: {max_episode_steps}"
         gym.utils.RecordConstructorArgs.__init__(
             self, max_episode_steps=max_episode_steps
         )
diff --git a/tests/wrappers/test_time_limit.py b/tests/wrappers/test_time_limit.py
index 2d52619638..b5ac8e899e 100644
--- a/tests/wrappers/test_time_limit.py
+++ b/tests/wrappers/test_time_limit.py
@@ -57,3 +57,24 @@ def patched_step(_action):
     _, _, terminated, truncated, _ = env.step(env.action_space.sample())
     assert terminated is True
     assert truncated is True
+
+
+def test_max_episode_steps():
+    env = gym.make("CartPole-v1", disable_env_checker=True)
+
+    assert env.spec.max_episode_steps == 500
+    assert TimeLimit(env, max_episode_steps=10).spec.max_episode_steps == 10
+
+    with pytest.raises(
+        AssertionError,
+        match="Expect the `max_episode_steps` to be positive, actually: -1",
+    ):
+        TimeLimit(env, max_episode_steps=-1)
+
+    with pytest.raises(
+        AssertionError,
+        match="Expect the `max_episode_steps` to be positive, actually: None",
+    ):
+        TimeLimit(env, max_episode_steps=None)
+
+    env.close()