@@ -27,16 +27,16 @@ def learn(self):
2727
2828def preprocess_fn (** kwargs ):
2929 # modify info before adding into the buffer
30- if kwargs .get ('info' , None ) is not None :
30+ # if info is not provided from env, it will be a ``Batch()``.
31+ if not kwargs .get ('info' , Batch ()).is_empty ():
3132 n = len (kwargs ['obs' ])
3233 info = kwargs ['info' ]
3334 for i in range (n ):
3435 info [i ].update (rew = kwargs ['rew' ][i ])
3536 return {'info' : info }
36- # or
37- # return Batch(info=info)
37+ # or: return Batch(info=info)
3838 else :
39- return {}
39+ return Batch ()
4040
4141
4242class Logger (object ):
@@ -119,6 +119,48 @@ def test_collector_with_dict_state():
119119 print (batch ['obs_next' ]['index' ])
120120
121121
122+ def test_collector_with_ma ():
123+ def reward_metric (x ):
124+ return x .sum ()
125+ env = MyTestEnv (size = 5 , sleep = 0 , ma_rew = 4 )
126+ policy = MyPolicy ()
127+ c0 = Collector (policy , env , ReplayBuffer (size = 100 ),
128+ preprocess_fn , reward_metric = reward_metric )
129+ r = c0 .collect (n_step = 3 )['rew' ]
130+ assert np .asanyarray (r ).size == 1 and r == 0.
131+ r = c0 .collect (n_episode = 3 )['rew' ]
132+ assert np .asanyarray (r ).size == 1 and r == 4.
133+ env_fns = [lambda x = i : MyTestEnv (size = x , sleep = 0 , ma_rew = 4 )
134+ for i in [2 , 3 , 4 , 5 ]]
135+ envs = VectorEnv (env_fns )
136+ c1 = Collector (policy , envs , ReplayBuffer (size = 100 ),
137+ preprocess_fn , reward_metric = reward_metric )
138+ r = c1 .collect (n_step = 10 )['rew' ]
139+ assert np .asanyarray (r ).size == 1 and r == 4.
140+ r = c1 .collect (n_episode = [2 , 1 , 1 , 2 ])['rew' ]
141+ assert np .asanyarray (r ).size == 1 and r == 4.
142+ batch = c1 .sample (10 )
143+ print (batch )
144+ c0 .buffer .update (c1 .buffer )
145+ obs = [
146+ 0. , 1. , 2. , 3. , 4. , 0. , 1. , 2. , 3. , 4. , 0. , 1. , 2. , 3. , 4. , 0. , 1. ,
147+ 0. , 1. , 2. , 0. , 1. , 0. , 1. , 2. , 3. , 0. , 1. , 2. , 3. , 4. , 0. , 1. , 0. ,
148+ 1. , 2. , 0. , 1. , 0. , 1. , 2. , 3. , 0. , 1. , 2. , 3. , 4. ]
149+ assert np .allclose (c0 .buffer [:len (c0 .buffer )].obs , obs )
150+ rew = [0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 , 1 ,
151+ 0 , 0 , 1 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 , 1 , 0 ,
152+ 0 , 1 , 0 , 1 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 ]
153+ assert np .allclose (c0 .buffer [:len (c0 .buffer )].rew ,
154+ [[x ] * 4 for x in rew ])
155+ c2 = Collector (policy , envs , ReplayBuffer (size = 100 , stack_num = 4 ),
156+ preprocess_fn , reward_metric = reward_metric )
157+ r = c2 .collect (n_episode = [0 , 0 , 0 , 10 ])['rew' ]
158+ assert np .asanyarray (r ).size == 1 and r == 4.
159+ batch = c2 .sample (10 )
160+ print (batch ['obs_next' ])
161+
162+
122163if __name__ == '__main__' :
123164 test_collector ()
124165 test_collector_with_dict_state ()
166+ test_collector_with_ma ()
0 commit comments