You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
"""Generates a single random sample from this space.
61
69
62
-
A sample will be chosen uniformly at random with the mask if provided
70
+
A sample will be chosen uniformly at random with the mask if provided, or it will be chosen according to a specified probability distribution if the probability mask is provided.
63
71
64
72
Args:
65
73
mask: An optional mask for if an action can be selected.
66
74
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
67
75
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
76
+
probability: An optional probability mask describing the probability of each action being selected.
77
+
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.float64`` where each value is in the range ``[0, 1]`` and the sum of all values is 1.
78
+
If the values do not sum to 1, an exception will be thrown.
68
79
69
80
Returns:
70
81
A sampled integer from the space
71
82
"""
72
-
ifmaskisnotNone:
83
+
ifmaskisnotNoneandprobabilityisnotNone:
84
+
raiseValueError(
85
+
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
86
+
)
87
+
# binary mask sampling
88
+
elifmaskisnotNone:
73
89
assertisinstance(
74
90
mask, np.ndarray
75
-
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
91
+
), f"The expected type of the sample mask is np.ndarray, actual type: {type(mask)}"
76
92
assert (
77
93
mask.dtype==np.int8
78
-
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
94
+
), f"The expected dtype of the sample mask is np.int8, actual dtype: {mask.dtype}"
79
95
assertmask.shape== (
80
96
self.n,
81
-
), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
97
+
), f"The expected shape of the sample mask is {(int(self.n),)}, actual shape: {mask.shape}"
98
+
82
99
valid_action_mask=mask==1
83
100
assertnp.all(
84
101
np.logical_or(mask==0, valid_action_mask)
85
-
), f"All values of a mask should be 0 or 1, actual values: {mask}"
102
+
), f"All values of the sample mask should be 0 or 1, actual values: {mask}"
103
+
86
104
ifnp.any(valid_action_mask):
87
105
returnself.start+self.np_random.choice(
88
106
np.where(valid_action_mask)[0]
89
107
)
90
108
else:
91
109
returnself.start
110
+
# probability mask sampling
111
+
elifprobabilityisnotNone:
112
+
assertisinstance(
113
+
probability, np.ndarray
114
+
), f"The expected type of the sample probability is np.ndarray, actual type: {type(probability)}"
115
+
assert (
116
+
probability.dtype==np.float64
117
+
), f"The expected dtype of the sample probability is np.float64, actual dtype: {probability.dtype}"
118
+
assertprobability.shape== (
119
+
self.n,
120
+
), f"The expected shape of the sample probability is {(int(self.n),)}, actual shape: {probability.shape}"
92
121
93
-
returnself.start+self.np_random.integers(self.n)
122
+
assertnp.all(
123
+
np.logical_and(probability>=0, probability<=1)
124
+
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
125
+
assertnp.isclose(
126
+
np.sum(probability), 1
127
+
), f"The sum of the sample probability should be equal to 1, actual sum: {np.sum(probability)}"
0 commit comments