File tree 1 file changed +14
-1
lines changed
1 file changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -59,6 +59,19 @@ def __init__(
59
59
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space.
60
60
start: Optionally, the starting value the element of each class will take (defaults to 0).
61
61
"""
62
+ # determine dtype
63
+ if dtype is None :
64
+ raise ValueError (
65
+ "MultiDiscrete dtype must be explicitly provided, cannot be None."
66
+ )
67
+ self .dtype = np .dtype (dtype )
68
+
69
+ # * check that dtype is an accepted dtype
70
+ if not (np .issubdtype (self .dtype , np .integer )):
71
+ raise ValueError (
72
+ f"Invalid MultiDiscrete dtype ({ self .dtype } ), must be an integer dtype"
73
+ )
74
+
62
75
self .nvec = np .array (nvec , dtype = dtype , copy = True )
63
76
if start is not None :
64
77
self .start = np .array (start , dtype = dtype , copy = True )
@@ -70,7 +83,7 @@ def __init__(
70
83
), "start and nvec (counts) should have the same shape"
71
84
assert (self .nvec > 0 ).all (), "nvec (counts) have to be positive"
72
85
73
- super ().__init__ (self .nvec .shape , dtype , seed )
86
+ super ().__init__ (self .nvec .shape , self . dtype , seed )
74
87
75
88
@property
76
89
def shape (self ) -> tuple [int , ...]:
You can’t perform that action at this time.
0 commit comments