Skip to content

Commit e4c1f90

Browse files
Add probability masking to space.sample (#1310)
Co-authored-by: Mario Jerez <[email protected]>
1 parent 1dffcc6 commit e4c1f90

21 files changed

+1049
-178
lines changed

gymnasium/spaces/box.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def is_bounded(self, manner: str = "both") -> bool:
342342
f"manner is not in {{'below', 'above', 'both'}}, actual value: {manner}"
343343
)
344344

345-
def sample(self, mask: None = None) -> NDArray[Any]:
345+
def sample(self, mask: None = None, probability: None = None) -> NDArray[Any]:
346346
r"""Generates a single random sample inside the Box.
347347
348348
In creating a sample of the box, each coordinate is sampled (independently) from a distribution
@@ -355,6 +355,7 @@ def sample(self, mask: None = None) -> NDArray[Any]:
355355
356356
Args:
357357
mask: A mask for sampling values from the Box space, currently unsupported.
358+
probability: A probability mask for sampling values from the Box space, currently unsupported.
358359
359360
Returns:
360361
A sampled value from the Box
@@ -363,6 +364,10 @@ def sample(self, mask: None = None) -> NDArray[Any]:
363364
raise gym.error.Error(
364365
f"Box.sample cannot be provided a mask, actual value: {mask}"
365366
)
367+
elif probability is not None:
368+
raise gym.error.Error(
369+
f"Box.sample cannot be provided a probability mask, actual value: {probability}"
370+
)
366371

367372
high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
368373
sample = np.empty(self.shape)

gymnasium/spaces/dict.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -149,27 +149,49 @@ def seed(self, seed: int | dict[str, Any] | None = None) -> dict[str, int]:
149149
f"Expected seed type: dict, int or None, actual type: {type(seed)}"
150150
)
151151

152-
def sample(self, mask: dict[str, Any] | None = None) -> dict[str, Any]:
152+
def sample(
153+
self,
154+
mask: dict[str, Any] | None = None,
155+
probability: dict[str, Any] | None = None,
156+
) -> dict[str, Any]:
153157
"""Generates a single random sample from this space.
154158
155159
The sample is an ordered dictionary of independent samples from the constituent spaces.
156160
157161
Args:
158162
mask: An optional mask for each of the subspaces, expects the same keys as the space
163+
probability: An optional probability mask for each of the subspaces, expects the same keys as the space
159164
160165
Returns:
161166
A dictionary with the same key and sampled values from :attr:`self.spaces`
162167
"""
163-
if mask is not None:
168+
if mask is not None and probability is not None:
169+
raise ValueError(
170+
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
171+
)
172+
elif mask is not None:
164173
assert isinstance(
165174
mask, dict
166-
), f"Expects mask to be a dict, actual type: {type(mask)}"
175+
), f"Expected sample mask to be a dict, actual type: {type(mask)}"
167176
assert (
168177
mask.keys() == self.spaces.keys()
169-
), f"Expect mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
178+
), f"Expected sample mask keys to be same as space keys, mask keys: {mask.keys()}, space keys: {self.spaces.keys()}"
179+
170180
return {k: space.sample(mask=mask[k]) for k, space in self.spaces.items()}
181+
elif probability is not None:
182+
assert isinstance(
183+
probability, dict
184+
), f"Expected sample probability mask to be a dict, actual type: {type(probability)}"
185+
assert (
186+
probability.keys() == self.spaces.keys()
187+
), f"Expected sample probability mask keys to be same as space keys, mask keys: {probability.keys()}, space keys: {self.spaces.keys()}"
171188

172-
return {k: space.sample() for k, space in self.spaces.items()}
189+
return {
190+
k: space.sample(probability=probability[k])
191+
for k, space in self.spaces.items()
192+
}
193+
else:
194+
return {k: space.sample() for k, space in self.spaces.items()}
173195

174196
def contains(self, x: Any) -> bool:
175197
"""Return boolean specifying if x is a valid member of this space."""

gymnasium/spaces/discrete.py

+47-8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ class Discrete(Space[np.int64]):
2222
>>> observation_space = Discrete(3, start=-1, seed=42) # {-1, 0, 1}
2323
>>> observation_space.sample()
2424
np.int64(-1)
25+
>>> observation_space.sample(mask=np.array([0,0,1], dtype=np.int8))
26+
np.int64(1)
27+
>>> observation_space.sample(probability=np.array([0,0,1], dtype=np.float64))
28+
np.int64(1)
29+
>>> observation_space.sample(probability=np.array([0,0.3,0.7], dtype=np.float64))
30+
np.int64(1)
2531
"""
2632

2733
def __init__(
@@ -56,41 +62,74 @@ def is_np_flattenable(self):
5662
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
5763
return True
5864

59-
def sample(self, mask: MaskNDArray | None = None) -> np.int64:
65+
def sample(
66+
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
67+
) -> np.int64:
6068
"""Generates a single random sample from this space.
6169
62-
A sample will be chosen uniformly at random with the mask if provided
70+
A sample will be chosen uniformly at random with the mask if provided, or it will be chosen according to a specified probability distribution if the probability mask is provided.
6371
6472
Args:
6573
mask: An optional mask for if an action can be selected.
6674
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.int8`` where ``1`` represents valid actions and ``0`` invalid / infeasible actions.
6775
If there are no possible actions (i.e. ``np.all(mask == 0)``) then ``space.start`` will be returned.
76+
probability: An optional probability mask describing the probability of each action being selected.
77+
Expected `np.ndarray` of shape ``(n,)`` and dtype ``np.float64`` where each value is in the range ``[0, 1]`` and the sum of all values is 1.
78+
If the values do not sum to 1, an exception will be thrown.
6879
6980
Returns:
7081
A sampled integer from the space
7182
"""
72-
if mask is not None:
83+
if mask is not None and probability is not None:
84+
raise ValueError(
85+
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
86+
)
87+
# binary mask sampling
88+
elif mask is not None:
7389
assert isinstance(
7490
mask, np.ndarray
75-
), f"The expected type of the mask is np.ndarray, actual type: {type(mask)}"
91+
), f"The expected type of the sample mask is np.ndarray, actual type: {type(mask)}"
7692
assert (
7793
mask.dtype == np.int8
78-
), f"The expected dtype of the mask is np.int8, actual dtype: {mask.dtype}"
94+
), f"The expected dtype of the sample mask is np.int8, actual dtype: {mask.dtype}"
7995
assert mask.shape == (
8096
self.n,
81-
), f"The expected shape of the mask is {(self.n,)}, actual shape: {mask.shape}"
97+
), f"The expected shape of the sample mask is {(int(self.n),)}, actual shape: {mask.shape}"
98+
8299
valid_action_mask = mask == 1
83100
assert np.all(
84101
np.logical_or(mask == 0, valid_action_mask)
85-
), f"All values of a mask should be 0 or 1, actual values: {mask}"
102+
), f"All values of the sample mask should be 0 or 1, actual values: {mask}"
103+
86104
if np.any(valid_action_mask):
87105
return self.start + self.np_random.choice(
88106
np.where(valid_action_mask)[0]
89107
)
90108
else:
91109
return self.start
110+
# probability mask sampling
111+
elif probability is not None:
112+
assert isinstance(
113+
probability, np.ndarray
114+
), f"The expected type of the sample probability is np.ndarray, actual type: {type(probability)}"
115+
assert (
116+
probability.dtype == np.float64
117+
), f"The expected dtype of the sample probability is np.float64, actual dtype: {probability.dtype}"
118+
assert probability.shape == (
119+
self.n,
120+
), f"The expected shape of the sample probability is {(int(self.n),)}, actual shape: {probability.shape}"
92121

93-
return self.start + self.np_random.integers(self.n)
122+
assert np.all(
123+
np.logical_and(probability >= 0, probability <= 1)
124+
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
125+
assert np.isclose(
126+
np.sum(probability), 1
127+
), f"The sum of the sample probability should be equal to 1, actual sum: {np.sum(probability)}"
128+
129+
return self.start + self.np_random.choice(np.arange(self.n), p=probability)
130+
# uniform sampling
131+
else:
132+
return self.start + self.np_random.integers(self.n)
94133

95134
def contains(self, x: Any) -> bool:
96135
"""Return boolean specifying if x is a valid member of this space."""

gymnasium/spaces/graph.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@ def sample(
183183
NDArray[Any] | tuple[Any, ...] | None,
184184
]
185185
) = None,
186+
probability: None | (
187+
tuple[
188+
NDArray[Any] | tuple[Any, ...] | None,
189+
NDArray[Any] | tuple[Any, ...] | None,
190+
]
191+
) = None,
186192
num_nodes: int = 10,
187193
num_edges: int | None = None,
188194
) -> GraphInstance:
@@ -192,6 +198,9 @@ def sample(
192198
mask: An optional tuple of optional node and edge mask that is only possible with Discrete spaces
193199
(Box spaces don't support sample masks).
194200
If no ``num_edges`` is provided then the ``edge_mask`` is multiplied by the number of edges
201+
probability: An optional tuple of optional node and edge probability mask that is only possible with Discrete spaces
202+
(Box spaces don't support sample probability masks).
203+
If no ``num_edges`` is provided then the edge probability mask is multiplied by the number of edges
195204
num_nodes: The number of nodes that will be sampled, the default is `10` nodes
196205
num_edges: An optional number of edges, otherwise, a random number between `0` and :math:`num_nodes^2`
197206
@@ -202,10 +211,18 @@ def sample(
202211
num_nodes > 0
203212
), f"The number of nodes is expected to be greater than 0, actual value: {num_nodes}"
204213

205-
if mask is not None:
214+
if mask is not None and probability is not None:
215+
raise ValueError(
216+
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
217+
)
218+
elif mask is not None:
206219
node_space_mask, edge_space_mask = mask
220+
mask_type = "mask"
221+
elif probability is not None:
222+
node_space_mask, edge_space_mask = probability
223+
mask_type = "probability"
207224
else:
208-
node_space_mask, edge_space_mask = None, None
225+
node_space_mask = edge_space_mask = mask_type = None
209226

210227
# we only have edges when we have at least 2 nodes
211228
if num_edges is None:
@@ -228,15 +245,19 @@ def sample(
228245
assert num_edges is not None
229246

230247
sampled_node_space = self._generate_sample_space(self.node_space, num_nodes)
248+
assert sampled_node_space is not None
231249
sampled_edge_space = self._generate_sample_space(self.edge_space, num_edges)
232250

233-
assert sampled_node_space is not None
234-
sampled_nodes = sampled_node_space.sample(node_space_mask)
235-
sampled_edges = (
236-
sampled_edge_space.sample(edge_space_mask)
237-
if sampled_edge_space is not None
238-
else None
239-
)
251+
if mask_type is not None:
252+
node_sample_kwargs = {mask_type: node_space_mask}
253+
edge_sample_kwargs = {mask_type: edge_space_mask}
254+
else:
255+
node_sample_kwargs = edge_sample_kwargs = {}
256+
257+
sampled_nodes = sampled_node_space.sample(**node_sample_kwargs)
258+
sampled_edges = None
259+
if sampled_edge_space is not None:
260+
sampled_edges = sampled_edge_space.sample(**edge_sample_kwargs)
240261

241262
sampled_edge_links = None
242263
if sampled_edges is not None and num_edges > 0:

gymnasium/spaces/multi_binary.py

+31-4
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,29 @@ def is_np_flattenable(self):
5959
"""Checks whether this space can be flattened to a :class:`spaces.Box`."""
6060
return True
6161

62-
def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
62+
def sample(
63+
self, mask: MaskNDArray | None = None, probability: MaskNDArray | None = None
64+
) -> NDArray[np.int8]:
6365
"""Generates a single random sample from this space.
6466
6567
A sample is drawn by independent, fair coin tosses (one toss per binary variable of the space).
6668
6769
Args:
68-
mask: An optional np.ndarray to mask samples with expected shape of ``space.shape``.
69-
For ``mask == 0`` then the samples will be ``0`` and ``mask == 1` then random samples will be generated.
70+
mask: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape``.
71+
For ``mask == 0`` the samples will be ``0``, and for ``mask == 1`` the samples will be ``1``.
72+
For random samples, use a mask value of ``2``.
7073
The expected mask shape is the space shape and mask dtype is ``np.int8``.
74+
probability: An optional ``np.ndarray`` to mask samples with expected shape of ``space.shape`` where each element
75+
represents the probability of the corresponding sample element being a 1.
76+
The expected mask shape is the space shape and mask dtype is ``np.float64``.
7177
7278
Returns:
7379
Sampled values from space
7480
"""
81+
if mask is not None and probability is not None:
82+
raise ValueError(
83+
f"Only one of `mask` or `probability` can be provided, actual values: mask={mask}, probability={probability}"
84+
)
7585
if mask is not None:
7686
assert isinstance(
7787
mask, np.ndarray
@@ -91,8 +101,25 @@ def sample(self, mask: MaskNDArray | None = None) -> NDArray[np.int8]:
91101
self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype),
92102
mask.astype(self.dtype),
93103
)
104+
elif probability is not None:
105+
assert isinstance(
106+
probability, np.ndarray
107+
), f"The expected type of the probability is np.ndarray, actual type: {type(probability)}"
108+
assert (
109+
probability.dtype == np.float64
110+
), f"The expected dtype of the probability is np.float64, actual dtype: {probability.dtype}"
111+
assert (
112+
probability.shape == self.shape
113+
), f"The expected shape of the probability is {self.shape}, actual shape: {probability.shape}"
114+
assert np.all(
115+
np.logical_and(probability >= 0, probability <= 1)
116+
), f"All values of the sample probability should be between 0 and 1, actual values: {probability}"
94117

95-
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
118+
return (self.np_random.random(size=self.shape) <= probability).astype(
119+
self.dtype
120+
)
121+
else:
122+
return self.np_random.integers(low=0, high=2, size=self.n, dtype=self.dtype)
96123

97124
def contains(self, x: Any) -> bool:
98125
"""Return boolean specifying if x is a valid member of this space."""

0 commit comments

Comments
 (0)