Skip to content

Commit 1dffcc6

Browse files
authored
Allow AtariPreprocessing non-square observations (#1312)
1 parent a4f8cfb commit 1dffcc6

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

gymnasium/wrappers/atari_preprocessing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
self.game_over = False
143143

144144
_low, _high, _dtype = (0, 1, np.float32) if scale_obs else (0, 255, np.uint8)
145-
_shape = self.screen_size + (1 if grayscale_obs else 3,)
145+
_shape = (self.screen_size[1], self.screen_size[0], 1 if grayscale_obs else 3)
146146
if grayscale_obs and not grayscale_newaxis:
147147
_shape = _shape[:-1] # Remove channel axis
148148
self.observation_space = Box(low=_low, high=_high, shape=_shape, dtype=_dtype)

tests/wrappers/test_atari_preprocessing.py

+11
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@
4747
),
4848
(84, 84, 1),
4949
),
50+
(
51+
AtariPreprocessing(
52+
gym.make("ALE/Pong-v5"),
53+
screen_size=(160, 210),
54+
grayscale_obs=False,
55+
frame_skip=1,
56+
noop_max=0,
57+
grayscale_newaxis=True,
58+
),
59+
(210, 160, 3),
60+
),
5061
],
5162
)
5263
def test_atari_preprocessing_grayscale(env, expected_obs_shape):

0 commit comments

Comments
 (0)