@@ -158,15 +158,36 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
158158 # we avoid sampling from the last `num_unroll_steps` steps of the game segment.
159159 if pos_in_game_segment >= self ._cfg .game_segment_length - self ._cfg .num_unroll_steps - self ._cfg .td_steps :
160160 pos_in_game_segment = np .random .choice (self ._cfg .game_segment_length - self ._cfg .num_unroll_steps - self ._cfg .td_steps , 1 ).item ()
161- if pos_in_game_segment >= len (game_segment .action_segment ) - 1 :
162- pos_in_game_segment = np .random .choice (len (game_segment .action_segment ) - 1 , 1 ).item ()
161+
162+ segment_len = len (game_segment .action_segment )
163+ if pos_in_game_segment >= segment_len - 1 :
164+ # The position is at or past the last action index (segment_len - 1), so clamp it:
165+ # resample when at least two actions exist, otherwise fall back to position 0.
166+ if segment_len > 1 :
167+ # If the segment has at least 2 actions, we can safely sample from [0, len-2].
168+ # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
169+ pos_in_game_segment = np .random .choice (segment_len - 1 , 1 ).item ()
170+ else :
171+ # With 0 or 1 actions there is no earlier position to sample from, so fall back to 0.
172+ pos_in_game_segment = 0
173+
163174 else :
164175 # For environments with a fixed action space (e.g., Atari),
165176 # we can safely sample from the entire game segment range.
166177 if pos_in_game_segment >= self ._cfg .game_segment_length :
167178 pos_in_game_segment = np .random .choice (self ._cfg .game_segment_length , 1 ).item ()
168- if pos_in_game_segment >= len (game_segment .action_segment ) - 1 :
169- pos_in_game_segment = np .random .choice (len (game_segment .action_segment ) - 1 , 1 ).item ()
179+
180+ segment_len = len (game_segment .action_segment )
181+ if pos_in_game_segment >= segment_len - 1 :
182+ # The position is at or past the last action index (segment_len - 1), so clamp it:
183+ # resample when at least two actions exist, otherwise fall back to position 0.
184+ if segment_len > 1 :
185+ # If the segment has at least 2 actions, we can safely sample from [0, len-2].
186+ # The upper bound for np.random.choice is exclusive, so (segment_len - 1) is correct.
187+ pos_in_game_segment = np .random .choice (segment_len - 1 , 1 ).item ()
188+ else :
189+ # With 0 or 1 actions there is no earlier position to sample from, so fall back to 0.
190+ pos_in_game_segment = 0
170191
171192 pos_in_game_segment_list .append (pos_in_game_segment )
172193
0 commit comments