Skip to content

Commit 49d214d

Browse files
committed
fix(pu): fix sampling strategy for board games with short game length like TicTacToe
1 parent 98e4eb7 commit 49d214d

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

lzero/mcts/buffer/game_buffer.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,30 +153,58 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
153153
# [0, game_segment_length - num_unroll_steps] to avoid padded data.
154154

155155
if self._cfg.action_type == 'varied_action_space':
156-
# For some environments (e.g., Jericho), the action space size may be different.
157-
# To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
158-
# we avoid sampling from the last `num_unroll_steps` steps of the game segment.
159-
if pos_in_game_segment >= self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps:
160-
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps, 1).item()
161-
156+
# For varied action space environments (e.g., board games with short game length like TicTacToe)
157+
# We need to handle cases where game_segment_length might be smaller than num_unroll_steps + td_steps
158+
# Strategy: progressively relax sampling constraints to accommodate short games
159+
160+
# Step 1: Calculate ideal sampling upper bound
161+
# Ideally, reserve space for both num_unroll_steps and td_steps to ensure complete trajectories
162+
ideal_bound = self._cfg.game_segment_length - self._cfg.num_unroll_steps - self._cfg.td_steps
163+
164+
# Step 2: Handle different game length scenarios with graceful degradation
165+
if ideal_bound > 0:
166+
# Case A: Normal/long games - enough space for full unroll + td steps
167+
# This is the standard case for most Atari games
168+
sampling_upper_bound = ideal_bound
169+
else:
170+
# Case B: Short games - need to relax constraints
171+
# Try to at least reserve space for unroll steps (most critical for training)
172+
fallback_bound = self._cfg.game_segment_length - self._cfg.num_unroll_steps
173+
174+
if fallback_bound > 0:
175+
# Can still accommodate unroll steps, though td_steps might need padding
176+
sampling_upper_bound = fallback_bound
177+
else:
178+
# Case C: Very short games (e.g., TicTacToe with 5-9 moves)
179+
# Allow sampling from entire segment length, padding will be applied during unrolling
180+
# This allows sampling from position 0 (beginning of game) when necessary
181+
sampling_upper_bound = self._cfg.game_segment_length
182+
183+
# Ensure at least 1 to avoid np.random.choice errors
184+
if sampling_upper_bound <= 0:
185+
sampling_upper_bound = 1
186+
187+
# Step 3: Resample position if it exceeds calculated bound
188+
if pos_in_game_segment >= sampling_upper_bound:
189+
pos_in_game_segment = np.random.choice(sampling_upper_bound, 1).item()
190+
191+
# Step 4: Further adjust based on actual segment length (runtime check)
162192
segment_len = len(game_segment.action_segment)
163-
if pos_in_game_segment >= segment_len - 1:
164-
# If the segment is very short (length 0 or 1), we can't randomly sample a position
165-
# before the last one. The only safe position is 0.
193+
if pos_in_game_segment >= segment_len:
194+
# Position exceeds actual segment, resample within valid range
166195
if segment_len > 1:
167-
# If the segment has at least 2 actions, we can safely sample from [0, len-2].
168-
# The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
196+
# Sample from [0, segment_len-1] to allow at least 1 step forward
169197
pos_in_game_segment = np.random.choice(segment_len - 1, 1).item()
170198
else:
171-
# If segment length is 0 or 1, the only valid/safe position is 0.
199+
# Segment has 0 or 1 actions, can only use position 0
172200
pos_in_game_segment = 0
173201

174202
else:
175203
# For environments with a fixed action space (e.g., Atari),
176204
# we can safely sample from the entire game segment range.
177205
if pos_in_game_segment >= self._cfg.game_segment_length:
178206
pos_in_game_segment = np.random.choice(self._cfg.game_segment_length, 1).item()
179-
207+
180208
segment_len = len(game_segment.action_segment)
181209
if pos_in_game_segment >= segment_len - 1:
182210
# If the segment is very short (length 0 or 1), we can't randomly sample a position

0 commit comments

Comments
 (0)