@@ -147,41 +147,29 @@ def __init__(
147147
148148 def step (self , * args , ** kwargs ):
149149 """
150- Step through the environment, applying modifications and optional Dopamine-style pooling.
151-
152- The step logic consists of:
153- - Applying step modifications before each ALE step.
154- - Repeating for (frameskip - 1) steps, accumulating rewards.
155- - On each step, breaking if terminated/truncated.
156- - If dopamine pooling is enabled, saving each observation.
157- - Final step: repeat modification application, collect reward, then run post-detection modifications and fill buffers.
158- - Return processed observation in the required mode ("dqn", "obj", etc). If pooling, pool last two frames by max operation.
159-
160- Returns:
161- tuple: (obs, total_reward, terminated, truncated, info)
150+ Take a step in the game environment after altering the ram.
162151 """
163152 frameskip = self ._frameskip
164153 total_reward = 0.0
165154 terminated = truncated = False
166155 if self .dopamine_pooling :
167- last_two_obs = [] # Grayscale 84x84
168- last_two_org = [] # RGB
156+ last_two_obs = []
157+ last_two_org = []
169158
170- # Frame skipping (step through the environment several times per step call)
171- for _ in range (frameskip - 1 ):
159+ for i in range (frameskip - 1 ):
172160 for func in self .step_modifs :
173161 func ()
174162 obs , reward , terminated , truncated , info = self ._env .step (
175163 * args , ** kwargs )
176164 total_reward += float (reward )
177165 if terminated or truncated :
178166 break
179- if self .dopamine_pooling :
180- last_two_obs .append (cv2 .resize (cv2 .cvtColor (self .getScreenRGB (
181- ), cv2 .COLOR_RGB2GRAY ), (84 , 84 ), interpolation = cv2 .INTER_AREA ))
182- last_two_org .append (self .getScreenRGB ())
183167
184- # Final step for this overall step()
168+ if self .dopamine_pooling :
169+ last_two_obs .append (cv2 .resize (cv2 .cvtColor (self .getScreenRGB (
170+ ), cv2 .COLOR_RGB2GRAY ), (84 , 84 ), interpolation = cv2 .INTER_AREA ))
171+ last_two_org .append (self .getScreenRGB ())
172+
185173 for func in self .step_modifs :
186174 func ()
187175 obs , reward , terminated , truncated , info = self ._env .step (
@@ -194,24 +182,23 @@ def step(self, *args, **kwargs):
194182 func ()
195183 self ._fill_buffer ()
196184
197- # Prepare returned obs for appropriate mode
198185 if self .obs_mode == "dqn" :
199186 obs = np .array (self ._state_buffer_dqn )
200187 elif self .obs_mode == "obj" :
201188 obs = np .array (self ._state_buffer_ns )
202189
203- # Dopamine-style pooling for last two frames (if enabled)
204190 if self .dopamine_pooling :
205191 last_two_obs .append (cv2 .resize (cv2 .cvtColor (self .getScreenRGB (
206192 ), cv2 .COLOR_RGB2GRAY ), (84 , 84 ), interpolation = cv2 .INTER_AREA ))
207193 last_two_org .append (self .getScreenRGB ())
208194 merged_obs = np .maximum .reduce (last_two_obs )
209195 merged_org = np .maximum .reduce (last_two_org )
210- # Update state buffers and output obs
211- if self .create_dqn_stack and self . _state_buffer_dqn is not None :
196+
197+ if self .create_dqn_stack :
212198 self ._state_buffer_dqn [- 1 ] = merged_obs
213- if self .create_rgb_stack and self . _state_buffer_rgb is not None :
199+ if self .create_rgb_stack :
214200 self ._state_buffer_rgb [- 1 ] = merged_org
201+
215202 if self .obs_mode == "dqn" :
216203 obs [- 1 ] = merged_obs
217204 else :
0 commit comments