Skip to content

Commit 0f59e38

Browse files
authored
Fix venv wrapper reset retval error with gym env (#712)
* Fix venv wrapper reset retval error with gym env
* fix lint
1 parent f270e88 commit 0f59e38

File tree

4 files changed

+35
-20
lines changed

4 files changed

+35
-20
lines changed

test/base/test_env.py

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
)
1919
from tianshou.utils import RunningMeanStd
2020

21-
if __name__ == '__main__':
21+
if __name__ == "__main__":
2222
from env import MyTestEnv, NXEnv
2323
else: # pytest
2424
from test.base.env import MyTestEnv, NXEnv
@@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
8080
spent_time = time.time()
8181
while current_idx_start < len(action_list):
8282
A, B, C, D = v.step(action=act, id=env_ids)
83-
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
83+
b = Batch({"obs": A, "rew": B, "done": C, "info": D})
8484
env_ids = b.info.env_id
8585
o.append(b)
8686
current_idx_start += len(act)
@@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
175175
for info in infos:
176176
assert recurse_comp(infos[0], info)
177177

178-
if __name__ == '__main__':
178+
if __name__ == "__main__":
179179
t = [0] * len(venv)
180180
for i, e in enumerate(venv):
181181
t[i] = time.time()
@@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
186186
e.reset(np.where(done)[0])
187187
t[i] = time.time() - t[i]
188188
for i, v in enumerate(venv):
189-
print(f'{type(v)}: {t[i]:.6f}s')
189+
print(f"{type(v)}: {t[i]:.6f}s")
190190

191191
def assert_get(v, expected):
192192
assert v.get_env_attr("size") == expected
@@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
242242
assert isinstance(info[0], dict)
243243

244244

245+
def test_venv_wrapper_gym(num_envs: int = 4):
246+
# Issue 697
247+
envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(num_envs)])
248+
envs = VectorEnvNormObs(envs)
249+
obs_ref = envs.reset(return_info=False)
250+
obs, info = envs.reset(return_info=True)
251+
assert isinstance(obs_ref, np.ndarray)
252+
assert isinstance(obs, np.ndarray)
253+
assert isinstance(info, list)
254+
assert isinstance(info[0], dict)
255+
assert obs_ref.shape[0] == obs.shape[0] == len(info) == num_envs
256+
257+
245258
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
246259
eps = np.finfo(np.float32).eps.item()
247260
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@@ -309,7 +322,7 @@ def __init__(self):
309322
# check conversion is working properly for a batch of actions
310323
np.testing.assert_allclose(
311324
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
312-
np.array([original_act] * bsz)
325+
np.array([original_act] * bsz),
313326
)
314327
# convert multidiscrete with different action number per
315328
# dimension to discrete action space
@@ -321,7 +334,7 @@ def __init__(self):
321334
# check conversion is working properly for a batch of actions
322335
np.testing.assert_allclose(
323336
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
324-
np.array([env_m.action_space.nvec - 1] * bsz)
337+
np.array([env_m.action_space.nvec - 1] * bsz),
325338
)
326339

327340

@@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
352365
assert v.shape[0] == num_envs
353366

354367

355-
if __name__ == '__main__':
368+
if __name__ == "__main__":
356369
test_venv_norm_obs()
370+
test_venv_wrapper_gym()
357371
test_venv_wrapper_envpool()
372+
test_venv_wrapper_envpool_gym_reset_return_info()
358373
test_env_obs_dtype()
359374
test_vecenv()
360375
test_attr_unwrapped()

tianshou/data/collector.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
135135
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
136136
rval = self.env.reset(**gym_reset_kwargs)
137137
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
138-
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
138+
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
139139
)
140140
if returns_info:
141141
obs, info = rval
@@ -173,7 +173,7 @@ def _reset_env_with_ids(
173173
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
174174
rval = self.env.reset(global_ids, **gym_reset_kwargs)
175175
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
176-
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
176+
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
177177
)
178178
if returns_info:
179179
obs_reset, info = rval

tianshou/env/venv_wrappers.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def reset(
4141
self,
4242
id: Optional[Union[int, List[int], np.ndarray]] = None,
4343
**kwargs: Any,
44-
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
44+
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
4545
return self.venv.reset(id, **kwargs)
4646

4747
def step(
@@ -84,15 +84,15 @@ def reset(
8484
self,
8585
id: Optional[Union[int, List[int], np.ndarray]] = None,
8686
**kwargs: Any,
87-
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
88-
retval = self.venv.reset(id, **kwargs)
89-
reset_returns_info = isinstance(
90-
retval, (tuple, list)
91-
) and len(retval) == 2 and isinstance(retval[1], dict)
92-
if reset_returns_info:
93-
obs, info = retval
87+
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
88+
rval = self.venv.reset(id, **kwargs)
89+
returns_info = isinstance(rval, (tuple, list)) and (len(rval) == 2) and (
90+
isinstance(rval[1], dict) or isinstance(rval[1][0], dict)
91+
)
92+
if returns_info:
93+
obs, info = rval
9494
else:
95-
obs = retval
95+
obs = rval
9696

9797
if isinstance(obs, tuple):
9898
raise TypeError(
@@ -103,7 +103,7 @@ def reset(
103103
if self.obs_rms and self.update_obs_rms:
104104
self.obs_rms.update(obs)
105105
obs = self._norm_obs(obs)
106-
if reset_returns_info:
106+
if returns_info:
107107
return obs, info
108108
else:
109109
return obs

tianshou/env/venvs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,7 @@ def reset(
185185
self,
186186
id: Optional[Union[int, List[int], np.ndarray]] = None,
187187
**kwargs: Any,
188-
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
188+
) -> Union[np.ndarray, Tuple[np.ndarray, Union[dict, List[dict]]]]:
189189
"""Reset the state of some envs and return initial observations.
190190
191191
If id is None, reset the state of all the environments and return

0 commit comments

Comments
 (0)