1818)
1919from tianshou .utils import RunningMeanStd
2020
21- if __name__ == ' __main__' :
21+ if __name__ == " __main__" :
2222 from env import MyTestEnv , NXEnv
2323else : # pytest
2424 from test .base .env import MyTestEnv , NXEnv
@@ -80,7 +80,7 @@ def test_async_env(size=10000, num=8, sleep=0.1):
8080 spent_time = time .time ()
8181 while current_idx_start < len (action_list ):
8282 A , B , C , D = v .step (action = act , id = env_ids )
83- b = Batch ({' obs' : A , ' rew' : B , ' done' : C , ' info' : D })
83+ b = Batch ({" obs" : A , " rew" : B , " done" : C , " info" : D })
8484 env_ids = b .info .env_id
8585 o .append (b )
8686 current_idx_start += len (act )
@@ -175,7 +175,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
175175 for info in infos :
176176 assert recurse_comp (infos [0 ], info )
177177
178- if __name__ == ' __main__' :
178+ if __name__ == " __main__" :
179179 t = [0 ] * len (venv )
180180 for i , e in enumerate (venv ):
181181 t [i ] = time .time ()
@@ -186,7 +186,7 @@ def test_vecenv(size=10, num=8, sleep=0.001):
186186 e .reset (np .where (done )[0 ])
187187 t [i ] = time .time () - t [i ]
188188 for i , v in enumerate (venv ):
189- print (f' { type (v )} : { t [i ]:.6f} s' )
189+ print (f" { type (v )} : { t [i ]:.6f} s" )
190190
191191 def assert_get (v , expected ):
192192 assert v .get_env_attr ("size" ) == expected
@@ -242,6 +242,19 @@ def test_env_reset_optional_kwargs(size=10000, num=8):
242242 assert isinstance (info [0 ], dict )
243243
244244
245+ def test_venv_wrapper_gym (num_envs : int = 4 ):
246+ # Issue 697
247+ envs = DummyVectorEnv ([lambda : gym .make ("CartPole-v1" ) for _ in range (num_envs )])
248+ envs = VectorEnvNormObs (envs )
249+ obs_ref = envs .reset (return_info = False )
250+ obs , info = envs .reset (return_info = True )
251+ assert isinstance (obs_ref , np .ndarray )
252+ assert isinstance (obs , np .ndarray )
253+ assert isinstance (info , list )
254+ assert isinstance (info [0 ], dict )
255+ assert obs_ref .shape [0 ] == obs .shape [0 ] == len (info ) == num_envs
256+
257+
245258def run_align_norm_obs (raw_env , train_env , test_env , action_list ):
246259 eps = np .finfo (np .float32 ).eps .item ()
247260 raw_obs , train_obs = [raw_env .reset ()], [train_env .reset ()]
@@ -309,7 +322,7 @@ def __init__(self):
309322 # check conversion is working properly for a batch of actions
310323 np .testing .assert_allclose (
311324 env_m .action (np .array ([env_m .action_space .nvec - 1 ] * bsz )),
312- np .array ([original_act ] * bsz )
325+ np .array ([original_act ] * bsz ),
313326 )
314327 # convert multidiscrete with different action number per
315328 # dimension to discrete action space
@@ -321,7 +334,7 @@ def __init__(self):
321334 # check conversion is working properly for a batch of actions
322335 np .testing .assert_allclose (
323336 env_d .action (np .array ([env_d .action_space .n - 1 ] * bsz )),
324- np .array ([env_m .action_space .nvec - 1 ] * bsz )
337+ np .array ([env_m .action_space .nvec - 1 ] * bsz ),
325338 )
326339
327340
@@ -352,9 +365,11 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
352365 assert v .shape [0 ] == num_envs
353366
354367
355- if __name__ == ' __main__' :
368+ if __name__ == " __main__" :
356369 test_venv_norm_obs ()
370+ test_venv_wrapper_gym ()
357371 test_venv_wrapper_envpool ()
372+ test_venv_wrapper_envpool_gym_reset_return_info ()
358373 test_env_obs_dtype ()
359374 test_vecenv ()
360375 test_attr_unwrapped ()
0 commit comments