Skip to content

Commit d95581d

Browse files
Removing TerrainMetadata class - Refactored terrain metadata generation
1 parent 34bbd2e commit d95581d

File tree

1 file changed

+89
-182
lines changed

1 file changed

+89
-182
lines changed

Diff for: gymnasium/envs/box2d/bipedal_walker.py

+89-182
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
__credits__ = ["Andrea PIERRÉ"]
22

33
import math
4-
from copy import deepcopy
54
from typing import TYPE_CHECKING, List, Optional
65

76
import numpy as np
@@ -81,137 +80,6 @@
8180
)
8281

8382

84-
class TerrainMetadata:
    """
    ## Description
    Metadata handler for the terrain of the BipedalWalker environment.

    The handler stores the ordered obstacle states of a terrain together
    with the per-obstacle values needed to rebuild them. State codes are
    `1` (stump), `2` (stairs) and `3` (pit).

    Built without metadata, the handler is in *generate* mode (`mode()`
    returns `True`) and values are recorded through `add_state` and
    `set_metadata`. Built from a metadata dict, it is in *replay* mode and
    values are consumed in FIFO order through `get_state` and
    `get_metadata`. The dict handed to the constructor is never modified.

    ## Raw Example
    ```python
    import gymnasium as gym

    env = gym.make("BipedalWalker-v3", hardcore=True, render_mode="human")
    env.reset()

    bipedal_env = env.unwrapped
    metadata = bipedal_env.terrain_metadata()
    options = dict(metadata=metadata)
    env.reset(options=options)
    ```

    ## Designed Example
    ```python
    import gymnasium as gym

    OBSTACLES = dict(
        down_stairs=dict(state=2, metadata=(-1, 4, 2)),
        up_stairs=dict(state=2, metadata=(1, 4, 2)),
        small_stump=dict(state=1, metadata=1),
        large_stump=dict(state=1, metadata=3),
        hole=dict(state=3, metadata=2),
    )

    env = gym.make("BipedalWalker-v3", hardcore=True, render_mode="human")
    metadata = dict(
        designed=True,
        states=[OBSTACLES["up_stairs"], OBSTACLES["hole"], OBSTACLES["large_stump"]],
        x_variations=False,
        y_variations=False
    )
    options = dict(metadata=metadata)
    env.reset(options=options)
    ```

    <!-- ## References -->

    ## Credits
    Created by Arthur Plautz Ventura

    """

    def __init__(self, metadata: Optional[dict] = None):
        """Create a handler, optionally loading an existing metadata dict.

        Args:
            metadata: Either a dict shaped like the output of `get_dict`, or
                a "designed" dict (`designed=True`, see the class docstring).
                `None` or an empty dict selects generate mode.
        """
        self._states = []  # Obstacle state codes, in terrain order
        self._metadata = {1: [], 2: [], 3: []}  # Per-state obstacle values
        self._y_variations = True
        self._x_variations = False
        self.__generate = False

        if metadata:
            self.__from_dict(metadata)  # Load values from existing metadata
        else:
            self.__generate = True  # New values should be generated

    @property
    def grass_y_variations(self):
        """Flag read by the terrain generator to add random height noise to grass."""
        return self._y_variations

    @property
    def grass_x_variations(self):
        """Flag read by the terrain generator to randomize grass segment length."""
        return self._x_variations

    def _pit_length(self, metadata=None):
        """Return the fixed horizontal length counted for a pit (value ignored)."""
        return 4

    def _stairs_length(self, metadata):
        """Return the stairs length: step width times number of steps."""
        _, stair_width, stair_steps = metadata
        return stair_width * stair_steps

    def _stump_length(self, metadata):
        """Return the stump length, which is the stored value itself."""
        return metadata

    def get_obstacles_length(self):
        """Return `(total_length, n_states)` for the stored obstacles.

        The stored values are only read, never consumed, so the method can be
        called repeatedly.

        Raises:
            KeyError: If a state code is unknown.
            IndexError: If a state has no matching stored value.
        """
        state_length = {
            1: self._stump_length,
            2: self._stairs_length,
            3: self._pit_length,
        }

        # Walk each per-state list with its own cursor instead of popping from
        # a deep copy of every stored value.
        cursors = {}
        total_length = 0
        for state in self._states:
            length_map = state_length[state]
            position = cursors.get(state, 0)
            total_length += length_map(self._metadata[state][position])
            cursors[state] = position + 1
        return total_length, len(self._states)

    def get_dict(self):
        """Return an independent dict snapshot of the handler's content."""
        return dict(
            states=deepcopy(self._states),
            metadata=deepcopy(self._metadata),
            x_variations=self._x_variations,
            y_variations=self._y_variations,
        )

    def __from_dict(self, metadata: dict):
        """Load the handler state from `metadata` without aliasing it."""
        if metadata.get("designed", False):
            # Designed terrains default to flat grass (no y noise) with
            # randomized grass length, unless the dict says otherwise.
            self._y_variations = metadata.get("y_variations", False)
            self._x_variations = metadata.get("x_variations", True)
            for state_obj in metadata.get("states", []):
                state = state_obj["state"]
                self._states.append(state)
                self._metadata[state].append(state_obj["metadata"])
        else:
            # Copy the containers: get_state()/get_metadata() pop from them,
            # which would otherwise empty the caller's dict and make it
            # unusable for a second reset.
            self._states = deepcopy(metadata.get("states", self._states))
            self._metadata = deepcopy(metadata.get("metadata", self._metadata))
            self._y_variations = metadata.get("y_variations", self._y_variations)
            self._x_variations = metadata.get("x_variations", self._x_variations)

    def mode(self) -> bool:
        """Return `True` in generate mode, `False` when replaying metadata."""
        return self.__generate

    def get_metadata(self, state: int):
        """Remove and return the oldest stored value for `state`."""
        return self._metadata[state].pop(0)

    def set_metadata(self, state: int, value: object):
        """Append `value` to the values stored for `state`."""
        self._metadata[state].append(value)

    def get_state(self) -> int:
        """Remove and return the oldest stored state code."""
        return self._states.pop(0)

    def add_state(self, state: int):
        """Append a state code to the stored sequence."""
        self._states.append(state)
213-
214-
21583
class ContactDetector(contactListener):
21684
def __init__(self, env):
21785
contactListener.__init__(self)
@@ -306,9 +174,15 @@ class BipedalWalker(gym.Env, EzPickle):
306174
"render_fps": FPS,
307175
}
308176

309-
def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
177+
def __init__(
178+
self,
179+
render_mode: Optional[str] = None,
180+
hardcore: bool = False,
181+
fall_down_penalty: bool = True,
182+
):
310183
EzPickle.__init__(self, render_mode, hardcore)
311184
self.isopen = True
185+
self.fall_down_penaly = fall_down_penalty
312186

313187
self.world = Box2D.b2World()
314188
self.terrain: List[Box2D.b2Body] = []
@@ -398,7 +272,7 @@ def __init__(self, render_mode: Optional[str] = None, hardcore: bool = False):
398272
self.render_mode = render_mode
399273
self.screen: Optional[pygame.Surface] = None
400274
self.clock = None
401-
self._terrain_metadata = None
275+
self._terrain_metadata = {}
402276

403277
def _destroy(self):
404278
if not self.terrain:
@@ -414,19 +288,79 @@ def _destroy(self):
414288
self.legs = []
415289
self.joints = []
416290

417-
def terrain_metadata(self):
    """Return the current terrain metadata as a dict, or None when unset.

    The dict comes from the `get_dict()` method of the object stored in
    `self._terrain_metadata`.
    """
    handler = self._terrain_metadata
    if not handler:
        return None
    return handler.get_dict()
291+
def _process_terrain_metadata(self):
    """Prepare the terrain-generation settings from `self._terrain_metadata`.

    Sets:
        self._predefined_terrain: True when the metadata carries at least one
            stump/stairs/pit state to replay, False when obstacles must be
            drawn at random.
        self._terrain_grass_x_variation / self._terrain_grass_y_variation:
            Values of the "x_variation" / "y_variation" keys (default False).
        self.terrain_grass: Grass length between obstacles. For a predefined
            terrain this is the terrain length left after the obstacles,
            split evenly between them; otherwise it is `TERRAIN_GRASS`.
        self._terrain_metadata: Rebuilt so it always holds a private
            "states" list (the states to replay, or an empty list to record
            into).
    """
    STUMP, STAIRS, PIT = 1, 2, 3
    PIT_LENGTH = 4  # Default pit x size

    metadata = self._terrain_metadata or {}

    # Defines if the length of the grass between obstacles should be randomly distributed
    self._terrain_grass_x_variation = metadata.get("x_variation", False)
    # Defines if the grass height should randomly vary
    # NOTE(review): with no metadata this defaults to False, i.e. flat grass;
    # confirm this is the intended default for randomly generated terrain.
    self._terrain_grass_y_variation = metadata.get("y_variation", False)

    # Copy the list: the terrain generator pops from and appends to it, which
    # must not reorder the object the caller passed to reset().
    states = list(metadata.get("states") or [])

    obstacles_length = []
    for state_object in states:
        # Look the fields up by key rather than relying on dict value order.
        state = state_object["state"]
        state_metadata = state_object["metadata"]
        if state == STUMP:
            obstacles_length.append(state_metadata)  # Stump metadata is the stump size
        elif state == STAIRS:
            _, stair_width, stair_steps = state_metadata
            obstacles_length.append(stair_width * stair_steps)  # Stairs total length
        elif state == PIT:
            obstacles_length.append(PIT_LENGTH)
        # Any other state code is left out of the length budget.

    # A terrain is only replayed when there is something to replay. Metadata
    # such as {"states": []} (what this method leaves behind after a reset
    # that recorded no obstacles) used to count as predefined and then
    # divided by zero below.
    self._predefined_terrain = bool(obstacles_length)

    if self._predefined_terrain:
        # Total grass portion of the terrain, split evenly between obstacles
        self.terrain_grass = (TERRAIN_LENGTH - sum(obstacles_length)) // len(
            obstacles_length
        )
        self._terrain_metadata = {**metadata, "states": states}
    else:
        self.terrain_grass = TERRAIN_GRASS
        self._terrain_metadata = {**metadata, "states": []}
331+
332+
def _generate_terrain_state(self, state: int):
    """Return the next obstacle state, or the values for the given obstacle.

    Called with `state == GRASS` (0) it returns the state code of the next
    obstacle. Called with an obstacle state it returns that obstacle's
    values and appends a `dict(state=..., metadata=...)` record to
    `self._terrain_metadata["states"]`.

    When `self._predefined_terrain` is true, the answers come from
    `self._terrain_metadata["states"]`: GRASS peeks at the first entry, an
    obstacle state pops it (and the record appended afterwards puts it back
    at the end of the list). Otherwise the answers are drawn from
    `self.np_random`.

    Args:
        state: 0 (grass), 1 (stump), 2 (stairs) or 3 (pit).

    Returns:
        For grass, the next obstacle state code. For a stump or pit, an
        integer counter. For stairs, a `(height, width, steps)` tuple. For a
        predefined terrain, whatever was stored under "metadata".

    Raises:
        IndexError: If a predefined terrain has no stored states.
        ValueError: If `state` is not a known code during random generation.
    """
    GRASS, STUMP, STAIRS, PIT, _STATES_ = range(5)
    recorded_states = self._terrain_metadata["states"]

    if self._predefined_terrain:
        if not recorded_states:
            raise IndexError("predefined terrain metadata has no obstacle states")
        if state == GRASS:
            return recorded_states[0]["state"]
        state_metadata = recorded_states.pop(0)["metadata"]
    elif state == GRASS:
        return self.np_random.integers(1, _STATES_)
    elif state == STUMP:
        state_metadata = self.np_random.integers(1, 3)
    elif state == STAIRS:
        # Keep this draw order so seeded runs produce the same terrain.
        stair_height = +1 if self.np_random.random() > 0.5 else -1
        stair_width = self.np_random.integers(4, 5)
        stair_steps = self.np_random.integers(3, 5)
        state_metadata = (stair_height, stair_width, stair_steps)
    elif state == PIT:
        state_metadata = self.np_random.integers(3, 5)
    else:
        # Previously fell through to an unbound `state_metadata`.
        raise ValueError(f"unknown terrain state: {state!r}")

    recorded_states.append(dict(state=state, metadata=state_metadata))
    return state_metadata
361+
362+
def _generate_terrain(self, hardcore):
363+
self._process_terrain_metadata()
430364

431365
GRASS, STUMP, STAIRS, PIT, _STATES_ = range(5)
432366
state = GRASS
@@ -446,19 +380,12 @@ def _generate_terrain(self, hardcore):
446380

447381
if state == GRASS and not oneshot:
448382
velocity = 0.8 * velocity + 0.01 * np.sign(TERRAIN_HEIGHT - y)
449-
if self._terrain_metadata.grass_y_variations and i > TERRAIN_STARTPAD:
383+
if self._terrain_grass_y_variation and i > TERRAIN_STARTPAD:
450384
velocity += self.np_random.uniform(-1, 1) / SCALE # 1
451385
y += velocity
452386

453387
elif state == PIT and oneshot:
454-
if generate:
455-
counter = self.np_random.integers(3, 5)
456-
self._terrain_metadata.set_metadata(state=PIT, value=counter)
457-
else:
458-
counter = self._terrain_metadata.get_metadata(state=PIT)
459-
if not counter:
460-
counter = self.np_random.integers(3, 5)
461-
388+
counter = self._generate_terrain_state(state)
462389
poly = [
463390
(x, y),
464391
(x + TERRAIN_STEP, y),
@@ -485,12 +412,7 @@ def _generate_terrain(self, hardcore):
485412
y -= 4 * TERRAIN_STEP
486413

487414
elif state == STUMP and oneshot:
488-
if generate:
489-
counter = self.np_random.integers(1, 3)
490-
self._terrain_metadata.set_metadata(state=STUMP, value=counter)
491-
else:
492-
counter = self._terrain_metadata.get_metadata(state=STUMP)
493-
415+
counter = self._generate_terrain_state(state)
494416
poly = [
495417
(x, y),
496418
(x + counter * TERRAIN_STEP, y),
@@ -503,17 +425,9 @@ def _generate_terrain(self, hardcore):
503425
self.terrain.append(t)
504426

505427
elif state == STAIRS and oneshot:
506-
if generate:
507-
stair_height = +1 if self.np_random.random() > 0.5 else -1
508-
stair_width = self.np_random.integers(4, 5)
509-
stair_steps = self.np_random.integers(3, 5)
510-
self._terrain_metadata.set_metadata(
511-
state=STAIRS, value=(stair_height, stair_width, stair_steps)
512-
)
513-
else:
514-
stair_height, stair_width, stair_steps = (
515-
self._terrain_metadata.get_metadata(state=STAIRS)
516-
)
428+
stair_height, stair_width, stair_steps = self._generate_terrain_state(
429+
state
430+
)
517431

518432
original_y = y
519433
for s in range(stair_steps):
@@ -550,19 +464,15 @@ def _generate_terrain(self, hardcore):
550464
self.terrain_y.append(y)
551465
counter -= 1
552466
if counter == 0:
553-
if self._terrain_metadata.grass_x_variations:
467+
if self._terrain_grass_x_variation:
554468
counter = self.np_random.integers(
555469
self.terrain_grass / 2, self.terrain_grass
556470
)
557471
else:
558472
counter = self.terrain_grass
559473

560474
if state == GRASS and hardcore:
561-
if generate:
562-
state = self.np_random.integers(1, _STATES_)
563-
self._terrain_metadata.add_state(state)
564-
else:
565-
state = self._terrain_metadata.get_state()
475+
state = self._generate_terrain_state(state)
566476
oneshot = True
567477
else:
568478
state = GRASS
@@ -616,12 +526,8 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
616526
self.scroll = 0.0
617527
self.lidar_render = 0
618528

619-
if options and "metadata" in options.keys():
620-
metadata = options.get("metadata")
621-
self._terrain_metadata = TerrainMetadata(metadata)
622-
else:
623-
self._terrain_metadata = TerrainMetadata()
624-
529+
if options:
530+
self._terrain_metadata = options.get("metadata", {})
625531
self._generate_terrain(self.hardcore)
626532
self._generate_clouds()
627533

@@ -781,7 +687,8 @@ def step(self, action: np.ndarray):
781687

782688
terminated = False
783689
if self.game_over or pos[0] < 0:
784-
reward = -100
690+
if self.fall_down_penaly:
691+
reward = -100
785692
terminated = True
786693
if pos[0] > (TERRAIN_LENGTH - self.terrain_grass) * TERRAIN_STEP:
787694
terminated = True

0 commit comments

Comments
 (0)