Skip to content

Commit c4f54fc

Browse files
qgallouedecaraffin
andauthored
Handling multi-dimensional action spaces (#971)
* Handle non 1D action shape * Revert changes of observation (out of the scope of this PR) * Apply changes to DictReplayBuffer * Update tests * Rollout buffer n-D actions space handling * Remove error when non 1D action space * ActorCriticPolicy return action with the proper shape * remove useless reshape * Update changelog * Add tests Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 6ce33f5 commit c4f54fc

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

docs/misc/changelog.rst

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Bug Fixes:
1919
^^^^^^^^^^
2020
- Fixed the issue that ``predict`` does not always return action as ``np.ndarray`` (@qgallouedec)
2121
- Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
22+
- Added multidimensional action space support (@qgallouedec)
2223

2324
Deprecations:
2425
^^^^^^^^^^^^^

stable_baselines3/common/buffers.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,7 @@ def add(
247247
next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)
248248

249249
# Same, for actions
250-
if isinstance(self.action_space, spaces.Discrete):
251-
action = action.reshape((self.n_envs, self.action_dim))
250+
action = action.reshape((self.n_envs, self.action_dim))
252251

253252
# Copy to avoid modification by reference
254253
self.observations[self.pos] = np.array(obs).copy()
@@ -433,6 +432,9 @@ def add(
433432
if isinstance(self.observation_space, spaces.Discrete):
434433
obs = obs.reshape((self.n_envs,) + self.obs_shape)
435434

435+
# Same reshape, for actions
436+
action = action.reshape((self.n_envs, self.action_dim))
437+
436438
self.observations[self.pos] = np.array(obs).copy()
437439
self.actions[self.pos] = np.array(action).copy()
438440
self.rewards[self.pos] = np.array(reward).copy()
@@ -586,8 +588,7 @@ def add(
586588
self.next_observations[key][self.pos] = np.array(next_obs[key]).copy()
587589

588590
# Same reshape, for actions
589-
if isinstance(self.action_space, spaces.Discrete):
590-
action = action.reshape((self.n_envs, self.action_dim))
591+
action = action.reshape((self.n_envs, self.action_dim))
591592

592593
self.actions[self.pos] = np.array(action).copy()
593594
self.rewards[self.pos] = np.array(reward).copy()

stable_baselines3/common/distributions.py

-1
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,6 @@ def make_proba_distribution(
658658
dist_kwargs = {}
659659

660660
if isinstance(action_space, spaces.Box):
661-
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
662661
cls = StateDependentNoiseDistribution if use_sde else DiagGaussianDistribution
663662
return cls(get_action_dim(action_space), **dist_kwargs)
664663
elif isinstance(action_space, spaces.Discrete):

stable_baselines3/common/policies.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -336,8 +336,8 @@ def predict(
336336

337337
with th.no_grad():
338338
actions = self._predict(observation, deterministic=deterministic)
339-
# Convert to numpy
340-
actions = actions.cpu().numpy()
339+
# Convert to numpy, and reshape to the original action shape
340+
actions = actions.cpu().numpy().reshape((-1,) + self.action_space.shape)
341341

342342
if isinstance(self.action_space, gym.spaces.Box):
343343
if self.squash_output:
@@ -592,6 +592,7 @@ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tenso
592592
distribution = self._get_action_dist_from_latent(latent_pi)
593593
actions = distribution.get_actions(deterministic=deterministic)
594594
log_prob = distribution.log_prob(actions)
595+
actions = actions.reshape((-1,) + self.action_space.shape)
595596
return actions, values, log_prob
596597

597598
def _get_action_dist_from_latent(self, latent_pi: th.Tensor) -> Distribution:

tests/test_spaces.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,19 @@ def step(self, action):
3333
return self.observation_space.sample(), 0.0, False, {}
3434

3535

36+
class DummyMultidimensionalAction(gym.Env):
37+
def __init__(self):
38+
super().__init__()
39+
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
40+
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
41+
42+
def reset(self):
43+
return self.observation_space.sample()
44+
45+
def step(self, action):
46+
return self.observation_space.sample(), 0.0, False, {}
47+
48+
3649
@pytest.mark.parametrize("model_class", [SAC, TD3, DQN])
3750
@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8)])
3851
def test_identity_spaces(model_class, env):
@@ -53,22 +66,39 @@ def test_identity_spaces(model_class, env):
5366

5467

5568
@pytest.mark.parametrize("model_class", [A2C, DDPG, DQN, PPO, SAC, TD3])
56-
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1"])
69+
@pytest.mark.parametrize("env", ["Pendulum-v1", "CartPole-v1", DummyMultidimensionalAction()])
5770
def test_action_spaces(model_class, env):
71+
kwargs = {}
5872
if model_class in [SAC, DDPG, TD3]:
59-
supported_action_space = env == "Pendulum-v1"
73+
supported_action_space = env == "Pendulum-v1" or isinstance(env, DummyMultidimensionalAction)
74+
kwargs["learning_starts"] = 2
75+
kwargs["train_freq"] = 32
6076
elif model_class == DQN:
6177
supported_action_space = env == "CartPole-v1"
6278
elif model_class in [A2C, PPO]:
6379
supported_action_space = True
80+
kwargs["n_steps"] = 64
6481

6582
if supported_action_space:
66-
model_class("MlpPolicy", env)
83+
model = model_class("MlpPolicy", env, **kwargs)
84+
if isinstance(env, DummyMultidimensionalAction):
85+
model.learn(64)
6786
else:
6887
with pytest.raises(AssertionError):
6988
model_class("MlpPolicy", env)
7089

7190

91+
def test_sde_multi_dim():
92+
SAC(
93+
"MlpPolicy",
94+
DummyMultidimensionalAction(),
95+
learning_starts=10,
96+
use_sde=True,
97+
sde_sample_freq=2,
98+
use_sde_at_warmup=True,
99+
).learn(20)
100+
101+
72102
@pytest.mark.parametrize("model_class", [A2C, PPO, DQN])
73103
@pytest.mark.parametrize("env", ["Taxi-v3"])
74104
def test_discrete_obs_space(model_class, env):

0 commit comments

Comments
 (0)