Skip to content

Commit c94424b

Browse files
committed
fix
1 parent 56f9975 commit c94424b

1 file changed

Lines changed: 9 additions & 4 deletions

File tree

torchrl/envs/libs/gym.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -438,7 +438,10 @@ def convert_multidiscrete_spec(
438438
remap_state_to_observation=None,
439439
batch_size=None,
440440
):
441-
if len(spec.nvec.shape) == 1:
441+
# Only use MultiCategorical/MultiOneHot for heterogeneous nvec (e.g., [3, 5, 7]).
442+
# Homogeneous nvec like [2, 2] typically represents independent actions
443+
# (e.g., vectorized envs with same Discrete(n) per env) and should use stacking.
444+
if len(spec.nvec.shape) == 1 and len(np.unique(spec.nvec)) > 1:
442445
dtype = (
443446
numpy_to_torch_dtype_dict[spec.dtype]
444447
if categorical_action_encoding
@@ -1408,18 +1411,20 @@ def _make_specs(self, env: gym.Env, batch_size=None) -> None: # noqa: F821
14081411
if self._categorical_action_encoding
14091412
else torch.long
14101413
)
1411-
# Flattened categorical: n = product(nvec), shape = mask shape
1414+
# Flattened action: single choice from prod(nvec) options.
1415+
# The mask (which has shape matching nvec) will be reshaped
1416+
# by Categorical/OneHot.update_mask when applied.
14121417
if self._categorical_action_encoding:
14131418
action_spec = Categorical(
14141419
prod_n,
1415-
shape=mask_spec.shape,
1420+
shape=(),
14161421
device=self.device,
14171422
dtype=dtype,
14181423
)
14191424
else:
14201425
action_spec = OneHot(
14211426
prod_n,
1422-
shape=(*mask_spec.shape, prod_n),
1427+
shape=(prod_n,),
14231428
device=self.device,
14241429
dtype=torch.bool,
14251430
)

0 commit comments

Comments
 (0)