Skip to content

Commit 19ac5c3

Browse files
committed
Change the step function; clean up code.
1 parent 90a2d08 commit 19ac5c3

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

hackatari/core.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -147,41 +147,29 @@ def __init__(
147147

148148
def step(self, *args, **kwargs):
149149
"""
150-
Step through the environment, applying modifications and optional Dopamine-style pooling.
151-
152-
The step logic consists of:
153-
- Applying step modifications before each ALE step.
154-
- Repeating for (frameskip - 1) steps, accumulating rewards.
155-
- On each step, breaking if terminated/truncated.
156-
- If dopamine pooling is enabled, saving each observation.
157-
- Final step: repeat modification application, collect reward, then run post-detection modifications and fill buffers.
158-
- Return processed observation in the required mode ("dqn", "obj", etc). If pooling, pool last two frames by max operation.
159-
160-
Returns:
161-
tuple: (obs, total_reward, terminated, truncated, info)
150+
Take a step in the game environment after altering the ram.
162151
"""
163152
frameskip = self._frameskip
164153
total_reward = 0.0
165154
terminated = truncated = False
166155
if self.dopamine_pooling:
167-
last_two_obs = [] # Grayscale 84x84
168-
last_two_org = [] # RGB
156+
last_two_obs = []
157+
last_two_org = []
169158

170-
# Frame skipping (step through the environment several times per step call)
171-
for _ in range(frameskip - 1):
159+
for i in range(frameskip-1):
172160
for func in self.step_modifs:
173161
func()
174162
obs, reward, terminated, truncated, info = self._env.step(
175163
*args, **kwargs)
176164
total_reward += float(reward)
177165
if terminated or truncated:
178166
break
179-
if self.dopamine_pooling:
180-
last_two_obs.append(cv2.resize(cv2.cvtColor(self.getScreenRGB(
181-
), cv2.COLOR_RGB2GRAY), (84, 84), interpolation=cv2.INTER_AREA))
182-
last_two_org.append(self.getScreenRGB())
183167

184-
# Final step for this overall step()
168+
if self.dopamine_pooling:
169+
last_two_obs.append(cv2.resize(cv2.cvtColor(self.getScreenRGB(
170+
), cv2.COLOR_RGB2GRAY), (84, 84), interpolation=cv2.INTER_AREA))
171+
last_two_org.append(self.getScreenRGB())
172+
185173
for func in self.step_modifs:
186174
func()
187175
obs, reward, terminated, truncated, info = self._env.step(
@@ -194,24 +182,23 @@ def step(self, *args, **kwargs):
194182
func()
195183
self._fill_buffer()
196184

197-
# Prepare returned obs for appropriate mode
198185
if self.obs_mode == "dqn":
199186
obs = np.array(self._state_buffer_dqn)
200187
elif self.obs_mode == "obj":
201188
obs = np.array(self._state_buffer_ns)
202189

203-
# Dopamine-style pooling for last two frames (if enabled)
204190
if self.dopamine_pooling:
205191
last_two_obs.append(cv2.resize(cv2.cvtColor(self.getScreenRGB(
206192
), cv2.COLOR_RGB2GRAY), (84, 84), interpolation=cv2.INTER_AREA))
207193
last_two_org.append(self.getScreenRGB())
208194
merged_obs = np.maximum.reduce(last_two_obs)
209195
merged_org = np.maximum.reduce(last_two_org)
210-
# Update state buffers and output obs
211-
if self.create_dqn_stack and self._state_buffer_dqn is not None:
196+
197+
if self.create_dqn_stack:
212198
self._state_buffer_dqn[-1] = merged_obs
213-
if self.create_rgb_stack and self._state_buffer_rgb is not None:
199+
if self.create_rgb_stack:
214200
self._state_buffer_rgb[-1] = merged_org
201+
215202
if self.obs_mode == "dqn":
216203
obs[-1] = merged_obs
217204
else:

0 commit comments

Comments (0)