Skip to content

Commit 05d6d82

Browse files
committed
Remove outdated 3.8 and 3.9 references
Add ToArray wrapper to docs Add array-api-extra for reliable testing Fix smaller mistakes in the docs Add NumPy version check for Array API wrappers
1 parent 9556549 commit 05d6d82

File tree

6 files changed

+39
-28
lines changed

6 files changed

+39
-28
lines changed

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ To install the base Gymnasium library, use `pip install gymnasium`
3131

3232
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
3333

34-
We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
34+
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
3535

3636
## API
3737

Diff for: docs/api/wrappers/misc_wrappers.md

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ title: Misc Wrappers
2727
## Data Conversion Wrappers
2828

2929
```{eval-rst}
30+
.. autoclass:: gymnasium.wrappers.ToArray
3031
.. autoclass:: gymnasium.wrappers.JaxToNumpy
3132
.. autoclass:: gymnasium.wrappers.JaxToTorch
3233
.. autoclass:: gymnasium.wrappers.NumpyToTorch

Diff for: docs/api/wrappers/table.md

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ wrapper in the page on the wrapper type
3434
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
3535
* - :class:`HumanRendering`
3636
- Allows human like rendering for environments that support "rgb_array" rendering.
37+
* - :class:`ToArray`
38+
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
3739
* - :class:`JaxToNumpy`
3840
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
3941
* - :class:`JaxToTorch`

Diff for: gymnasium/wrappers/to_array.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
import functools
2020
import numbers
2121
from collections import abc
22-
from types import ModuleType
22+
from types import ModuleType, NoneType
2323
from typing import Any, Iterable, Mapping, SupportsFloat
2424

25+
import numpy as np
26+
from packaging.version import Version
27+
2528
import gymnasium as gym
2629
from gymnasium.core import RenderFrame, WrapperActType, WrapperObsType
2730
from gymnasium.error import DependencyNotInstalled
@@ -36,10 +39,12 @@
3639
)
3740

3841

42+
if Version(np.__version__) < Version("2.1.0"):
43+
raise DependencyNotInstalled("Array API functionality requires numpy >= 2.1.0")
44+
45+
3946
__all__ = ["ToArray", "to_xp"]
4047

41-
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
42-
_NoneType = type(None)
4348
Array = Any # TODO: Switch to ArrayAPI type once https://github.com/data-apis/array-api/pull/589 is merged
4449
Device = Any # TODO: Switch to ArrayAPI type if available
4550

@@ -85,7 +90,7 @@ def module_namespace(module: ModuleType) -> ModuleType:
8590

8691
@functools.singledispatch
8792
def to_xp(value: Any, xp: ModuleType, device: Device | None = None) -> Any:
88-
"""Converts a value into the specified xp module array type."""
93+
"""Convert a value into the specified xp module array type."""
8994
raise Exception(
9095
f"No known conversion for ({type(value)}) to xp module ({xp}) registered. Report as issue on github."
9196
)
@@ -103,40 +108,47 @@ def _number_to_xp(
103108
def _mapping_to_xp(
104109
value: Mapping[str, Any], xp: ModuleType, device: Device | None = None
105110
) -> Mapping[str, Any]:
106-
"""Converts a mapping of PyTorch Tensors into a Dictionary of Jax Array."""
111+
"""Convert a mapping of Arrays into a Dictionary of the specified xp module array type."""
107112
return type(value)(**{k: to_xp(v, xp, device) for k, v in value.items()})
108113

109114

110115
@to_xp.register(abc.Iterable)
111116
def _iterable_to_xp(
112117
value: Iterable[Any], xp: ModuleType, device: Device | None = None
113118
) -> Iterable[Any]:
114-
"""Converts an Iterable from PyTorch Tensors to an iterable of Jax Array."""
119+
"""Convert an Iterable from Arrays to an iterable of the specified xp module array type."""
115120
# There is currently no type for ArrayAPI compatible objects, so they fall through to this
116121
# function registered for any Iterable. If they are arrays, we can convert them directly.
117122
# We currently cannot pass the device to the from_dlpack function, since it is not supported
118123
# for some frameworks (see e.g. https://github.com/data-apis/array-api-compat/issues/204)
119124
if is_array_api_obj(value):
120-
try:
121-
x = xp.from_dlpack(value)
122-
return to_device(x, device) if device is not None else x
123-
except (RuntimeError, BufferError):
124-
# If dlpack fails (e.g. because the array is read-only for frameworks that do not
125-
# support it), we create a copy of the array that we own and then convert it.
126-
# TODO: The correct treatment of read-only arrays is currently not fully clear in the
127-
# Array API. Once ongoing discussions are resolved, we should update this code to remove
128-
# any fallbacks.
129-
value_namespace = array_namespace(value)
130-
value_copy = value_namespace.asarray(value, copy=True)
131-
return xp.asarray(value_copy, device=device)
125+
return _array_api_to_xp(value, xp, device)
132126
if hasattr(value, "_make"):
133127
# namedtuple - underline used to prevent potential name conflicts
134128
# noinspection PyProtectedMember
135129
return type(value)._make(to_xp(v, xp, device) for v in value)
136130
return type(value)(to_xp(v, xp, device) for v in value)
137131

138132

139-
@to_xp.register(_NoneType)
133+
def _array_api_to_xp(
134+
value: Array, xp: ModuleType, device: Device | None = None
135+
) -> Array:
136+
"""Convert an Array API compatible array to the specified xp module array type."""
137+
try:
138+
x = xp.from_dlpack(value)
139+
return to_device(x, device) if device is not None else x
140+
except (RuntimeError, BufferError):
141+
# If dlpack fails (e.g. because the array is read-only for frameworks that do not
142+
# support it), we create a copy of the array that we own and then convert it.
143+
# TODO: The correct treatment of read-only arrays is currently not fully clear in the
144+
# Array API. Once ongoing discussions are resolved, we should update this code to remove
145+
# any fallbacks.
146+
value_namespace = array_namespace(value)
147+
value_copy = value_namespace.asarray(value, copy=True)
148+
return xp.asarray(value_copy, device=device)
149+
150+
151+
@to_xp.register(NoneType)
140152
def _none_to_xp(value: None, xp: ModuleType, device: Device | None = None) -> None:
141153
"""Passes through None values."""
142154
return value

Diff for: pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ classifiers = [
2525
dependencies = [
2626
"numpy >=1.21.0",
2727
"cloudpickle >=1.2.0",
28-
"importlib-metadata >=4.8.0; python_version < '3.10'",
2928
"typing-extensions >=4.3.0",
3029
"farama-notifications >=0.0.1",
3130
]
@@ -88,6 +87,7 @@ testing = [
8887
"pytest >=7.1.3",
8988
"scipy >=1.7.3",
9089
"dill >=0.3.7",
90+
"array_api_extra >=0.7.0",
9191
]
9292

9393
[project.urls]
@@ -131,7 +131,7 @@ exclude = ["tests/**", "**/node_modules", "**/__pycache__"]
131131
strict = []
132132

133133
typeCheckingMode = "basic"
134-
pythonVersion = "3.8"
134+
pythonVersion = "3.10"
135135
pythonPlatform = "All"
136136
typeshedPath = "typeshed"
137137
enableTypeIgnoreComments = true

Diff for: tests/wrappers/test_to_array.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111

1212
array_api_compat = pytest.importorskip("array_api_compat")
13+
array_api_extra = pytest.importorskip("array_api_extra")
1314

1415
from array_api_compat import array_namespace, is_array_api_obj # noqa: E402
1516

@@ -51,12 +52,7 @@ def xp_data_equivalence(data_1, data_2) -> bool:
5152
xp_data_equivalence(o_1, o_2) for o_1, o_2 in zip(data_1, data_2)
5253
)
5354
elif is_array_api_obj(data_1):
54-
# Avoid a dependency on array-api-extra
55-
# Otherwise, we could use xpx.isclose(data_1, data_2, atol=0.00001).all()
56-
same_device = data_1.device == data_2.device
57-
a = np.asarray(data_1)
58-
b = np.asarray(data_2)
59-
return np.allclose(a, b, atol=0.00001) and same_device
55+
return array_api_extra.isclose(data_1, data_2, atol=0.00001).all()
6056
else:
6157
return data_1 == data_2
6258
else:

0 commit comments

Comments
 (0)