@@ -153,30 +153,58 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
153153 # [0, game_segment_length - num_unroll_steps] to avoid padded data.
154154
155155 if self ._cfg .action_type == 'varied_action_space' :
156- # For some environments (e.g., Jericho), the action space size may be different.
157- # To ensure we can always unroll `num_unroll_steps` steps starting from the sampled position (without exceeding segment length),
158- # we avoid sampling from the last `num_unroll_steps` steps of the game segment.
159- if pos_in_game_segment >= self ._cfg .game_segment_length - self ._cfg .num_unroll_steps - self ._cfg .td_steps :
160- pos_in_game_segment = np .random .choice (self ._cfg .game_segment_length - self ._cfg .num_unroll_steps - self ._cfg .td_steps , 1 ).item ()
161-
156+ # For varied action space environments (e.g., board games with short game length like TicTacToe)
157+ # We need to handle cases where game_segment_length might be smaller than num_unroll_steps + td_steps
158+ # Strategy: progressively relax sampling constraints to accommodate short games
159+
160+ # Step 1: Calculate ideal sampling upper bound
161+ # Ideally, reserve space for both num_unroll_steps and td_steps to ensure complete trajectories
162+ ideal_bound = self ._cfg .game_segment_length - self ._cfg .num_unroll_steps - self ._cfg .td_steps
163+
164+ # Step 2: Handle different game length scenarios with graceful degradation
165+ if ideal_bound > 0 :
166+ # Case A: Normal/long games - enough space for full unroll + td steps
167+ # This is the standard case whenever game_segment_length exceeds num_unroll_steps + td_steps
168+ sampling_upper_bound = ideal_bound
169+ else :
170+ # Case B: Short games - need to relax constraints
171+ # Try to at least reserve space for unroll steps (most critical for training)
172+ fallback_bound = self ._cfg .game_segment_length - self ._cfg .num_unroll_steps
173+
174+ if fallback_bound > 0 :
175+ # Can still accommodate unroll steps, though td_steps might need padding
176+ sampling_upper_bound = fallback_bound
177+ else :
178+ # Case C: Very short games (e.g., TicTacToe with 5-9 moves)
179+ # Allow sampling from entire segment length, padding will be applied during unrolling
180+ # This allows sampling from position 0 (beginning of game) when necessary
181+ sampling_upper_bound = self ._cfg .game_segment_length
182+
183+ # Ensure at least 1 to avoid np.random.choice errors
184+ if sampling_upper_bound <= 0 :
185+ sampling_upper_bound = 1
186+
187+ # Step 3: Resample position if it exceeds calculated bound
188+ if pos_in_game_segment >= sampling_upper_bound :
189+ pos_in_game_segment = np .random .choice (sampling_upper_bound , 1 ).item ()
190+
191+ # Step 4: Further adjust based on actual segment length (runtime check)
162192 segment_len = len (game_segment .action_segment )
163- if pos_in_game_segment >= segment_len - 1 :
164- # If the segment is very short (length 0 or 1), we can't randomly sample a position
165- # before the last one. The only safe position is 0.
193+ if pos_in_game_segment >= segment_len :
194+ # Position exceeds actual segment, resample within valid range
166195 if segment_len > 1 :
167- # If the segment has at least 2 actions, we can safely sample from [0, len-2].
168- # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
196+ # Sample from [0, segment_len-2] (np.random.choice's upper bound is exclusive) to allow at least 1 step forward
169197 pos_in_game_segment = np .random .choice (segment_len - 1 , 1 ).item ()
170198 else :
171- # If segment length is 0 or 1, the only valid/safe position is 0.
199+ # Segment has 0 or 1 actions, can only use position 0
172200 pos_in_game_segment = 0
173201
174202 else :
175203 # For environments with a fixed action space (e.g., Atari),
176204 # we can safely sample from the entire game segment range.
177205 if pos_in_game_segment >= self ._cfg .game_segment_length :
178206 pos_in_game_segment = np .random .choice (self ._cfg .game_segment_length , 1 ).item ()
179-
207+
180208 segment_len = len (game_segment .action_segment )
181209 if pos_in_game_segment >= segment_len - 1 :
182210 # If the segment is very short (length 0 or 1), we can't randomly sample a position
0 commit comments