Skip to content

Commit 1966254

Browse files
authored
Check that MultiDiscrete.dtype is not None (#1196)
1 parent 175202f commit 1966254

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

Diff for: gymnasium/spaces/multi_discrete.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,19 @@ def __init__(
5959
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
6060
start: Optionally, the starting value the element of each class will take (defaults to 0).
6161
"""
62+
# determine dtype
63+
if dtype is None:
64+
raise ValueError(
65+
"MultiDiscrete dtype must be explicitly provided, cannot be None."
66+
)
67+
self.dtype = np.dtype(dtype)
68+
69+
# * check that dtype is an accepted dtype
70+
if not (np.issubdtype(self.dtype, np.integer)):
71+
raise ValueError(
72+
f"Invalid MultiDiscrete dtype ({self.dtype}), must be an integer dtype"
73+
)
74+
6275
self.nvec = np.array(nvec, dtype=dtype, copy=True)
6376
if start is not None:
6477
self.start = np.array(start, dtype=dtype, copy=True)
@@ -70,7 +83,7 @@ def __init__(
7083
), "start and nvec (counts) should have the same shape"
7184
assert (self.nvec > 0).all(), "nvec (counts) have to be positive"
7285

73-
super().__init__(self.nvec.shape, dtype, seed)
86+
super().__init__(self.nvec.shape, self.dtype, seed)
7487

7588
@property
7689
def shape(self) -> tuple[int, ...]:

0 commit comments

Comments
 (0)