Skip to content

Commit aacf06f

Browse files
author
mavenlin
authored
Type checking Action/State (#265)
## Description - Introduce `TArray` to provide better type checking than `Array`. - Fix several bugs that can be detected when the tests are run in debug mode. ## Motivation and Context `Array` was a temporary design that is very unsafe. For example, in Atari, `state["done"_]` has an integer type, while the return type of `game_over()` is subject to change in the third-party code. We need automatic type casting rather than checking the types manually. - [x] I have NOT raised an issue to propose this change ([required](https://envpool.readthedocs.io/en/latest/pages/contributing.html) for new features and bug fixes) ## Types of changes What types of changes does your code introduce? Put an `x` in all the boxes that apply: - [x] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds core functionality) - [ ] New environment (non-breaking change which adds 3rd-party environment) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation (update in the documentation) - [ ] Example (update in the folder of example) ## Implemented Tasks - [x] TArray. - [x] Utilities to cast `std::vector<Array>` to `Dict<Keys, TupleOfTArray>`. - [x] Fix several bugs. - [x] Replace `State`/`Action` in `Env` with the type-safe version. - [x] Replace all uses of `Array` with `TArray`. ## Checklist Go over all the following points, and put an `x` in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help! - [x] I have read the [CONTRIBUTION](https://envpool.readthedocs.io/en/latest/pages/contributing.html) guide (**required**) - [ ] My change requires a change to the documentation. - [x] I have updated the tests accordingly (*required for a bug fix or a new feature*). - [ ] I have updated the documentation accordingly. 
- [x] I have reformatted the code using `make format` (**required**) - [x] I have checked the code using `make lint` (**required**) - [x] I have ensured that `make bazel-test` passes. (**required**)
1 parent 51f4686 commit aacf06f

33 files changed

+280
-194
lines changed

benchmark/test_envpool.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,7 @@
8282
env = envpool.make_gym(task_id, **kwargs)
8383
env.async_reset()
8484
env.action_space.seed(args.seed)
85-
action = np.array(
86-
[env.action_space.sample() for _ in range(args.batch_size)]
87-
)
85+
action = np.array([env.action_space.sample() for _ in range(args.batch_size)])
8886
t = time.time()
8987
for _ in tqdm.trange(args.total_step):
9088
info = env.recv()[-1]

benchmark/test_gym.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ def run(env, num_envs, total_step, async_):
3333
)
3434
else:
3535
env = gym.vector.make(
36-
task_id, num_envs, async_, lambda e: wrap_deepmind(
37-
e, episode_life=False, clip_rewards=False, frame_stack=4
38-
)
36+
task_id, num_envs, async_, lambda e:
37+
wrap_deepmind(e, episode_life=False, clip_rewards=False, frame_stack=4)
3938
)
4039
elif env == "mujoco":
4140
task_id = "Ant-v3"

envpool/atari/api_test.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ def test_spec(self) -> None:
3636
action_num = action_nums[task]
3737
spec = make_spec(task.capitalize() + "-v5")
3838
logging.info(spec)
39-
self.assertEqual(
40-
spec.action_array_spec["action"].maximum + 1, action_num
41-
)
39+
self.assertEqual(spec.action_array_spec["action"].maximum + 1, action_num)
4240
# check dm spec
4341
dm_obs_spec = spec.observation_spec().obs
4442
dm_act_spec = spec.action_spec()
@@ -126,9 +124,7 @@ def test_lowlevel_step(self) -> None:
126124
self.assertEqual(ts.observation.lives.dtype, np.int32)
127125
np.testing.assert_allclose(ts.observation.env_id, np.arange(num_envs))
128126
self.assertEqual(ts.observation.env_id.dtype, np.int32)
129-
np.testing.assert_allclose(
130-
ts.observation.players.env_id.shape, (num_envs,)
131-
)
127+
np.testing.assert_allclose(ts.observation.players.env_id.shape, (num_envs,))
132128
self.assertEqual(ts.observation.players.env_id.dtype, np.int32)
133129
action = {
134130
"env_id": np.arange(num_envs),
@@ -178,9 +174,7 @@ def test_highlevel_step(self) -> None:
178174
self.assertEqual(ts.observation.lives.dtype, np.int32)
179175
np.testing.assert_allclose(ts.observation.env_id, np.arange(num_envs))
180176
self.assertEqual(ts.observation.env_id.dtype, np.int32)
181-
np.testing.assert_allclose(
182-
ts.observation.players.env_id.shape, (num_envs,)
183-
)
177+
np.testing.assert_allclose(ts.observation.players.env_id.shape, (num_envs,))
184178
self.assertEqual(ts.observation.players.env_id.dtype, np.int32)
185179
action = {
186180
"env_id": np.arange(num_envs),

envpool/atari/atari_env_test.cc

Lines changed: 28 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ TEST(AtariEnvTest, GrayScaleMaxPoolOrder) {
3535
ptr1[i * n + j] = j;
3636
}
3737
}
38-
Array col0(Spec<uint8_t>({n, n, 3}));
39-
Array col1(Spec<uint8_t>({n, n, 3}));
40-
Array result(Spec<uint8_t>({n, n, 1}));
38+
TArray col0(Spec<uint8_t>({n, n, 3}));
39+
TArray col1(Spec<uint8_t>({n, n, 3}));
40+
TArray result(Spec<uint8_t>({n, n, 1}));
4141
auto* col0_ptr = static_cast<uint8_t*>(col0.Data());
4242
auto* col1_ptr = static_cast<uint8_t*>(col1.Data());
4343
auto* result_ptr = static_cast<uint8_t*>(result.Data());
@@ -97,19 +97,17 @@ TEST(AtariEnvTest, Seed) {
9797
atari::AtariEnvSpec spec(config);
9898
atari::AtariEnvPool envpool0(spec);
9999
atari::AtariEnvPool envpool1(spec);
100-
Array all_env_ids(Spec<int>({static_cast<int>(batch)}));
100+
TArray all_env_ids(Spec<int>({static_cast<int>(batch)}));
101101
for (std::size_t i = 0; i < batch; ++i) {
102102
all_env_ids[i] = i;
103103
}
104104
envpool0.Reset(all_env_ids);
105105
envpool1.Reset(all_env_ids);
106-
std::vector<Array> raw_action(3);
107-
AtariAction action(&raw_action);
106+
107+
AtariAction action;
108108
for (int i = 0; i < total_iter; ++i) {
109-
auto state_vec0 = envpool0.Recv();
110-
auto state_vec1 = envpool1.Recv();
111-
AtariState state0(&state_vec0);
112-
AtariState state1(&state_vec1);
109+
AtariState state0(envpool0.Recv());
110+
AtariState state1(envpool1.Recv());
113111
EXPECT_EQ(state0["obs"_].Shape(),
114112
std::vector<std::size_t>({batch, 4, 84, 84}));
115113
EXPECT_EQ(state1["obs"_].Shape(),
@@ -128,7 +126,7 @@ TEST(AtariEnvTest, Seed) {
128126
}
129127
action["env_id"_] = state0["info:env_id"_];
130128
action["players.env_id"_] = state0["info:env_id"_];
131-
action["action"_] = Array(Spec<int>({static_cast<int>(batch)}));
129+
action["action"_] = TArray(Spec<int>({static_cast<int>(batch)}));
132130
for (std::size_t j = 0; j < batch; ++j) {
133131
action["action"_][j] = std::rand() % 6;
134132
}
@@ -149,17 +147,15 @@ TEST(AtariEnvTest, MaxEpisodeSteps) {
149147
int total_iter = 100;
150148
atari::AtariEnvSpec spec(config);
151149
atari::AtariEnvPool envpool(spec);
152-
Array all_env_ids(Spec<int>({batch}));
150+
TArray all_env_ids(Spec<int>({batch}));
153151
for (int i = 0; i < batch; ++i) {
154152
all_env_ids[i] = i;
155153
}
156154
envpool.Reset(all_env_ids);
157-
std::vector<Array> raw_action(3);
158-
AtariAction action(&raw_action);
155+
AtariAction action;
159156
int count = 0;
160157
for (int i = 0; i < total_iter; ++i) {
161-
auto state_vec = envpool.Recv();
162-
AtariState state(&state_vec);
158+
AtariState state(envpool.Recv());
163159
auto elapsed_step = state["elapsed_step"_];
164160
for (int j = 0; j < batch; ++j) {
165161
EXPECT_EQ(count, static_cast<int>(elapsed_step[j]));
@@ -169,7 +165,7 @@ TEST(AtariEnvTest, MaxEpisodeSteps) {
169165
}
170166
action["env_id"_] = state["info:env_id"_];
171167
action["players.env_id"_] = state["info:env_id"_];
172-
action["action"_] = Array(Spec<int>({batch}));
168+
action["action"_] = TArray(Spec<int>({batch}));
173169
for (int j = 0; j < batch; ++j) {
174170
action["action"_][j] = 0;
175171
}
@@ -188,18 +184,16 @@ TEST(AtariEnvTest, EpisodicLife) {
188184
config["task"_] = "pong";
189185
atari::AtariEnvSpec spec(config);
190186
atari::AtariEnvPool envpool(spec);
191-
Array all_env_ids(Spec<int>({batch}));
187+
TArray all_env_ids(Spec<int>({batch}));
192188
for (int i = 0; i < batch; ++i) {
193189
all_env_ids[i] = i;
194190
}
195191
envpool.Reset(all_env_ids);
196-
std::vector<Array> raw_action(3);
197-
AtariAction action(&raw_action);
192+
AtariAction action;
198193
std::vector<bool> last_done(batch);
199194
std::vector<int> last_lives(batch);
200195
for (int i = 0; i < total_iter; ++i) {
201-
auto state_vec = envpool.Recv();
202-
AtariState state(&state_vec);
196+
AtariState state(envpool.Recv());
203197
auto done = state["done"_];
204198
auto lives = state["info:lives"_];
205199
for (int j = 0; j < batch; ++j) {
@@ -211,7 +205,7 @@ TEST(AtariEnvTest, EpisodicLife) {
211205
}
212206
action["env_id"_] = state["info:env_id"_];
213207
action["players.env_id"_] = state["info:env_id"_];
214-
action["action"_] = Array(Spec<int>({batch}));
208+
action["action"_] = TArray(Spec<int>({batch}));
215209
for (int j = 0; j < batch; ++j) {
216210
action["action"_][j] = std::rand() % 6;
217211
}
@@ -225,8 +219,7 @@ TEST(AtariEnvTest, EpisodicLife) {
225219
last_lives = std::vector<int>(4);
226220
last_done = std::vector<bool>(4, true);
227221
for (int i = 0; i < total_iter; ++i) {
228-
auto state_vec = envpool2.Recv();
229-
AtariState state(&state_vec);
222+
AtariState state(envpool2.Recv());
230223
auto done = state["done"_];
231224
auto lives = state["info:lives"_];
232225
for (int j = 0; j < batch; ++j) {
@@ -250,7 +243,7 @@ TEST(AtariEnvTest, EpisodicLife) {
250243
}
251244
action["env_id"_] = state["info:env_id"_];
252245
action["players.env_id"_] = state["info:env_id"_];
253-
action["action"_] = Array(Spec<int>({batch}));
246+
action["action"_] = TArray(Spec<int>({batch}));
254247
for (int j = 0; j < batch; ++j) {
255248
action["action"_][j] = i % 4;
256249
}
@@ -271,22 +264,18 @@ TEST(AtariEnvTest, ZeroDiscountOnLifeLoss) {
271264
config["zero_discount_on_life_loss"_] = true;
272265
atari::AtariEnvSpec spec2(config);
273266
atari::AtariEnvPool envpool2(spec2);
274-
Array all_env_ids(Spec<int>({batch}));
267+
TArray all_env_ids(Spec<int>({batch}));
275268
for (int i = 0; i < batch; ++i) {
276269
all_env_ids[i] = i;
277270
}
278271
envpool.Reset(all_env_ids);
279272
envpool2.Reset(all_env_ids);
280-
std::vector<Array> raw_action(3);
281-
AtariAction action(&raw_action);
273+
AtariAction action;
282274
std::vector<bool> last_done(batch, true);
283275
std::vector<int> last_lives(batch);
284276
for (int i = 0; i < total_iter; ++i) {
285-
auto state_vec = envpool.Recv();
286-
auto state_vec2 = envpool2.Recv();
287-
AtariState state(&state_vec);
288-
AtariState state2(&state_vec2);
289-
277+
AtariState state(envpool.Recv());
278+
AtariState state2(envpool2.Recv());
290279
auto done = state["done"_];
291280
auto lives = state["info:lives"_];
292281
auto discount = state["discount"_];
@@ -326,7 +315,7 @@ TEST(AtariEnvTest, ZeroDiscountOnLifeLoss) {
326315
}
327316
action["env_id"_] = state["info:env_id"_];
328317
action["players.env_id"_] = state["info:env_id"_];
329-
action["action"_] = Array(Spec<int>({batch}));
318+
action["action"_] = TArray(Spec<int>({batch}));
330319
for (int j = 0; j < batch; ++j) {
331320
action["action"_][j] = i % 4;
332321
}
@@ -351,22 +340,20 @@ TEST(AtariEnvSpeedTest, Benchmark) {
351340
config["thread_affinity_offset"_] = 0;
352341
atari::AtariEnvSpec spec(config);
353342
atari::AtariEnvPool envpool(spec);
354-
Array all_env_ids(Spec<int>({num_envs}));
343+
TArray all_env_ids(Spec<int>({num_envs}));
355344
for (int i = 0; i < num_envs; ++i) {
356345
all_env_ids[i] = i;
357346
}
358347
envpool.Reset(all_env_ids);
359-
std::vector<Array> raw_action(3);
360-
AtariAction action(&raw_action);
361-
action["action"_] = Array(Spec<int>({batch}));
348+
AtariAction action;
349+
action["action"_] = TArray(Spec<int>({batch}));
362350
for (int j = 0; j < batch; ++j) {
363351
action["action"_][j] = 1;
364352
}
365353
auto start = std::chrono::system_clock::now();
366354
for (int i = 0; i < total_iter; ++i) {
367355
// recv
368-
auto state_vec = envpool.Recv();
369-
AtariState state(&state_vec);
356+
AtariState state(envpool.Recv());
370357
auto env_id = state["info:env_id"_];
371358
// EXPECT_EQ(env_id.Shape(),
372359
// std::vector<std::size_t>({(std::size_t)batch}));

envpool/box2d/box2d_correctness_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,7 @@ def solve_bipedal_walker(
262262
env_id = env_id[~done]
263263
hs = hs[~done]
264264

265-
ah = [
266-
self.heuristic_bipedal_walker_policy(s, h) for s, h in zip(obs, hs)
267-
]
265+
ah = [self.heuristic_bipedal_walker_policy(s, h) for s, h in zip(obs, hs)]
268266
action = np.array([i[0] for i in ah])
269267
hs = np.array([i[1] for i in ah])
270268

0 commit comments

Comments
 (0)