Skip to content

Commit 5069425

Browse files
puyuan1996zjowowen
andauthored
fix(pu): fix pos_in_game_segment bug in buffer (#414)
Co-authored-by: zjowowen <zjowowen@outlook.com>
1 parent 90e44a6 commit 5069425

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

lzero/mcts/buffer/game_buffer.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,15 +158,36 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
158158
# we avoid sampling from the last `num_unroll_steps` steps of the game segment.
159159
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
160160
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
161-
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
162-
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
161+
162+
segment_len = len(game_segment.action_segment)
163+
if pos_in_game_segment >= segment_len - 1:
164+
# If the segment is very short (length 0 or 1), we can't randomly sample a position
165+
# before the last one. The only safe position is 0.
166+
if segment_len > 1:
167+
# If the segment has at least 2 actions, we can safely sample from [0, len-2].
168+
# The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
169+
pos_in_game_segment = np.random.choice(segment_len - 1, 1).item()
170+
else:
171+
# If segment length is 0 or 1, the only valid/safe position is 0.
172+
pos_in_game_segment = 0
173+
163174
else:
164175
# For environments with a fixed action space (e.g., Atari),
165176
# we can safely sample from the entire game segment range.
166177
if pos_in_game_segment >= self._cfg.game_segment_length:
167178
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
168-
if pos_in_game_segment >= len(game_segment.action_segment) - 1:
169-
pos_in_game_segment = np.random.choice(len(game_segment.action_segment) - 1, 1).item()
179+
180+
segment_len = len(game_segment.action_segment)
181+
if pos_in_game_segment >= segment_len - 1:
182+
# If the segment is very short (length 0 or 1), we can't randomly sample a position
183+
# before the last one. The only safe position is 0.
184+
if segment_len > 1:
185+
# If the segment has at least 2 actions, we can safely sample from [0, len-2].
186+
# The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
187+
pos_in_game_segment = np.random.choice(segment_len - 1, 1).item()
188+
else:
189+
# If segment length is 0 or 1, the only valid/safe position is 0.
190+
pos_in_game_segment = 0
170191

171192
pos_in_game_segment_list.append(pos_in_game_segment)
172193

zoo/classic_control/cartpole/config/cartpole_muzero_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
model_path=None,
4444
cuda=True,
4545
env_type='not_board_games',
46-
action_type='varied_action_space',
4746
game_segment_length=50,
4847
update_per_collect=update_per_collect,
4948
batch_size=batch_size,

0 commit comments

Comments
 (0)