Skip to content

Commit 26fb874

Browse files
Improve collector (#125)
* remove multibuf * reward_metric * make fields with empty Batch rather than None after reset * many fixes and refactor Co-authored-by: Trinkle23897 <463003665@qq.com>
1 parent 5599a6d commit 26fb874

File tree

3 files changed

+183
-177
lines changed

3 files changed

+183
-177
lines changed

test/base/env.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,34 @@
1-
import time
21
import gym
2+
import time
33
from gym.spaces.discrete import Discrete
44

55

66
class MyTestEnv(gym.Env):
7-
def __init__(self, size, sleep=0, dict_state=False):
7+
"""This is a "going right" task. The task is to go right ``size`` steps.
8+
"""
9+
10+
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
811
self.size = size
912
self.sleep = sleep
1013
self.dict_state = dict_state
14+
self.ma_rew = ma_rew
1115
self.action_space = Discrete(2)
1216
self.reset()
1317

1418
def reset(self, state=0):
1519
self.done = False
1620
self.index = state
21+
return self._get_dict_state()
22+
23+
def _get_reward(self):
24+
"""Generate a non-scalar reward if ma_rew is True."""
25+
x = int(self.done)
26+
if self.ma_rew > 0:
27+
return [x] * self.ma_rew
28+
return x
29+
30+
def _get_dict_state(self):
31+
"""Generate a dict_state if dict_state is True."""
1732
return {'index': self.index} if self.dict_state else self.index
1833

1934
def step(self, action):
@@ -23,22 +38,13 @@ def step(self, action):
2338
time.sleep(self.sleep)
2439
if self.index == self.size:
2540
self.done = True
26-
if self.dict_state:
27-
return {'index': self.index}, 0, True, {}
28-
else:
29-
return self.index, 0, True, {}
41+
return self._get_dict_state(), self._get_reward(), self.done, {}
3042
if action == 0:
3143
self.index = max(self.index - 1, 0)
32-
if self.dict_state:
33-
return {'index': self.index}, 0, False, {'key': 1, 'env': self}
34-
else:
35-
return self.index, 0, False, {}
44+
return self._get_dict_state(), self._get_reward(), self.done, \
45+
{'key': 1, 'env': self} if self.dict_state else {}
3646
elif action == 1:
3747
self.index += 1
3848
self.done = self.index == self.size
39-
if self.dict_state:
40-
return {'index': self.index}, int(self.done), self.done, \
41-
{'key': 1, 'env': self}
42-
else:
43-
return self.index, int(self.done), self.done, \
44-
{'key': 1, 'env': self}
49+
return self._get_dict_state(), self._get_reward(), \
50+
self.done, {'key': 1, 'env': self}

test/base/test_collector.py

Lines changed: 46 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,16 +27,16 @@ def learn(self):
2727

2828
def preprocess_fn(**kwargs):
2929
# modify info before adding into the buffer
30-
if kwargs.get('info', None) is not None:
30+
# if info is not provided from env, it will be a ``Batch()``.
31+
if not kwargs.get('info', Batch()).is_empty():
3132
n = len(kwargs['obs'])
3233
info = kwargs['info']
3334
for i in range(n):
3435
info[i].update(rew=kwargs['rew'][i])
3536
return {'info': info}
36-
# or
37-
# return Batch(info=info)
37+
# or: return Batch(info=info)
3838
else:
39-
return {}
39+
return Batch()
4040

4141

4242
class Logger(object):
@@ -119,6 +119,48 @@ def test_collector_with_dict_state():
119119
print(batch['obs_next']['index'])
120120

121121

122+
def test_collector_with_ma():
123+
def reward_metric(x):
124+
return x.sum()
125+
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
126+
policy = MyPolicy()
127+
c0 = Collector(policy, env, ReplayBuffer(size=100),
128+
preprocess_fn, reward_metric=reward_metric)
129+
r = c0.collect(n_step=3)['rew']
130+
assert np.asanyarray(r).size == 1 and r == 0.
131+
r = c0.collect(n_episode=3)['rew']
132+
assert np.asanyarray(r).size == 1 and r == 4.
133+
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
134+
for i in [2, 3, 4, 5]]
135+
envs = VectorEnv(env_fns)
136+
c1 = Collector(policy, envs, ReplayBuffer(size=100),
137+
preprocess_fn, reward_metric=reward_metric)
138+
r = c1.collect(n_step=10)['rew']
139+
assert np.asanyarray(r).size == 1 and r == 4.
140+
r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
141+
assert np.asanyarray(r).size == 1 and r == 4.
142+
batch = c1.sample(10)
143+
print(batch)
144+
c0.buffer.update(c1.buffer)
145+
obs = [
146+
0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
147+
0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
148+
1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
149+
assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
150+
rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
151+
0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
152+
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
153+
assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
154+
[[x] * 4 for x in rew])
155+
c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
156+
preprocess_fn, reward_metric=reward_metric)
157+
r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
158+
assert np.asanyarray(r).size == 1 and r == 4.
159+
batch = c2.sample(10)
160+
print(batch['obs_next'])
161+
162+
122163
if __name__ == '__main__':
123164
test_collector()
124165
test_collector_with_dict_state()
166+
test_collector_with_ma()

0 commit comments

Comments (0)