Add generic conversion wrapper between Array API compatible frameworks #1333


Open
wants to merge 17 commits into
base: main
2 changes: 1 addition & 1 deletion .github/workflows/docs-build-dev.yml
@@ -19,7 +19,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: pip install -r docs/requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/docs-build-release.yml
@@ -19,7 +19,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: pip install -r docs/requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/docs-manual-build.yml
@@ -33,7 +33,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.10'

- name: Install dependencies
run: pip install -r docs/requirements.txt
5 changes: 2 additions & 3 deletions .github/workflows/run-pytest.yml
@@ -10,8 +10,8 @@ jobs:
strategy:
fail-fast: true
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['>=1.21,<2.0', '>=2.0']
python-version: ['3.10', '3.11', '3.12']
numpy-version: ['>=1.21,<2.0', '>=2.1']
steps:
- uses: actions/checkout@v4
- run: |
@@ -22,7 +22,6 @@ jobs:
- name: Run tests
run: docker run gymnasium-all-docker pytest tests/*
- name: Run doctests
if: ${{ matrix.python-version != '3.8' }}
run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/

build-necessary:
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ To install the base Gymnasium library, use `pip install gymnasium`

This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.

We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
Member

Python 3.9 still has some time until EoL. I'd be fine with equipping this feature with an explicit warning or exception that says it doesn't work on 3.9, but I'd be very careful about removing support for that version just for this feature.

Contributor Author

That would render the wrappers temporarily unusable for people who previously had access to them. What's your take on this, @pseudo-rnd-thoughts? Also, EOL is in October, so that's not too far off.
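
For reference, the "explicit warning/exception" option discussed above could look roughly like the sketch below. This is not code from the PR: `check_python_version` and `MIN_PYTHON` are hypothetical names, and the 3.10 minimum is an assumption taken from this thread.

```python
import sys

# Assumption for this sketch: the new wrapper needs Python 3.10 or newer.
MIN_PYTHON = (3, 10)


def check_python_version(feature, version_info=None):
    """Raise a clear error if `feature` is used on an unsupported Python.

    `feature` is only a label for the error message. `version_info` defaults
    to the running interpreter and is a parameter purely to ease testing.
    """
    if version_info is None:
        version_info = sys.version_info
    current = tuple(version_info[:2])
    if current < MIN_PYTHON:
        raise RuntimeError(
            f"{feature} requires Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+, "
            f"but this interpreter is {current[0]}.{current[1]}."
        )


# On a supported interpreter the check is a no-op.
check_python_version("ToArray", version_info=(3, 12))
```

A guard like this would keep the package installable on 3.9 and fail only when the new wrapper is actually used. The author's objection still stands, though: wrappers that are rebuilt on top of the new one would then fail on 3.9 as well.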


## API

1 change: 1 addition & 0 deletions docs/api/wrappers/misc_wrappers.md
@@ -27,6 +27,7 @@ title: Misc Wrappers
## Data Conversion Wrappers

```{eval-rst}
.. autoclass:: gymnasium.wrappers.ToArray
.. autoclass:: gymnasium.wrappers.JaxToNumpy
.. autoclass:: gymnasium.wrappers.JaxToTorch
.. autoclass:: gymnasium.wrappers.NumpyToTorch
2 changes: 2 additions & 0 deletions docs/api/wrappers/table.md
@@ -34,6 +34,8 @@ wrapper in the page on the wrapper type
- Converts an image observation computed by ``reset`` and ``step`` from RGB to Grayscale.
* - :class:`HumanRendering`
- Allows human like rendering for environments that support "rgb_array" rendering.
* - :class:`ToArray`
- Wraps an environment based on any Array API compatible framework, e.g. torch, jax, numpy, such that it can be interacted with any other Array API compatible framework.
* - :class:`JaxToNumpy`
- Wraps a Jax-based environment such that it can be interacted with NumPy arrays.
* - :class:`JaxToTorch`
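
To illustrate the idea behind a generic conversion helper, here is a minimal sketch of a recursive converter that is specialised per framework with `functools.partial`, in the same style as the `functools.partial(to_xp, xp=np)` lines in this PR. It is not the PR's `to_xp`: the name `to_xp_sketch` and the stand-in namespace `float_xp` are hypothetical, and the only assumption about `xp` is that it exposes an `asarray` function, as Array API namespaces do.

```python
import functools
from collections import abc
from types import SimpleNamespace


def to_xp_sketch(value, xp):
    """Recursively convert the leaves of `value` with `xp.asarray`."""
    if value is None:
        # None is passed through unchanged.
        return None
    if isinstance(value, abc.Mapping):
        # Rebuild the mapping with converted values (keys must be strings here).
        return type(value)(**{k: to_xp_sketch(v, xp) for k, v in value.items()})
    if hasattr(value, "_make"):
        # namedtuple: rebuild through its _make constructor.
        return type(value)._make(to_xp_sketch(v, xp) for v in value)
    if isinstance(value, (list, tuple)):
        # Plain containers keep their type; only their items are converted.
        return type(value)(to_xp_sketch(v, xp) for v in value)
    # Everything else is treated as a leaf and handed to the target namespace.
    return xp.asarray(value)


# Stand-in "framework" so the sketch runs without third-party packages:
# its asarray simply converts a scalar to float.
float_xp = SimpleNamespace(asarray=float)

# Bind the target namespace once, then reuse the converter everywhere.
to_float = functools.partial(to_xp_sketch, xp=float_xp)

converted = to_float({"obs": [1, 2], "extra": (3,), "unused": None})
```

Passing a real namespace such as `numpy` as `xp` should work the same way for leaves, since NumPy also provides `asarray`. Note that this sketch walks into lists and tuples instead of turning them into a single array.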
150 changes: 6 additions & 144 deletions gymnasium/wrappers/jax_to_numpy.py
@@ -3,19 +3,16 @@
from __future__ import annotations

import functools
import numbers
from collections import abc
from typing import Any, Iterable, Mapping, SupportsFloat

import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.to_array import ToArray, to_xp


try:
import jax
import jax.numpy as jnp
except ImportError:
raise DependencyNotInstalled(
@@ -24,104 +21,13 @@

__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]

# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
_NoneType = type(None)

jax_to_numpy = functools.partial(to_xp, xp=np)

@functools.singledispatch
def numpy_to_jax(value: Any) -> Any:
"""Converts a value to a Jax Array."""
raise Exception(
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
)


@numpy_to_jax.register(numbers.Number)
def _number_to_jax(
value: numbers.Number,
) -> jax.Array:
"""Converts a number (int, float, etc.) to a Jax Array."""
assert jnp is not None
return jnp.array(value)


@numpy_to_jax.register(np.ndarray)
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
assert jnp is not None
return jnp.array(value, dtype=value.dtype)


@numpy_to_jax.register(abc.Mapping)
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
"""Converts a dictionary of numpy arrays to a mapping of Jax Array."""
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})


@numpy_to_jax.register(abc.Iterable)
def _iterable_numpy_to_jax(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(numpy_to_jax(v) for v in value)
else:
return type(value)(numpy_to_jax(v) for v in value)


@numpy_to_jax.register(_NoneType)
def _none_numpy_to_jax(value: None) -> None:
"""Passes through None values."""
return value


@functools.singledispatch
def jax_to_numpy(value: Any) -> Any:
"""Converts a value to a numpy array."""
raise Exception(
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
)


@jax_to_numpy.register(jax.Array)
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
"""Converts a Jax Array to a numpy array."""
return np.array(value)

numpy_to_jax = functools.partial(to_xp, xp=jnp)

@jax_to_numpy.register(abc.Mapping)
def _mapping_jax_to_numpy(
value: Mapping[str, jax.Array | Any],
) -> Mapping[str, np.ndarray | Any]:
"""Converts a dictionary of Jax Array to a mapping of numpy arrays."""
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})


@jax_to_numpy.register(abc.Iterable)
def _iterable_jax_to_numpy(
value: Iterable[np.ndarray | Any],
) -> Iterable[jax.Array | Any]:
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
if hasattr(value, "_make"):
# namedtuple - underline used to prevent potential name conflicts
# noinspection PyProtectedMember
return type(value)._make(jax_to_numpy(v) for v in value)
else:
return type(value)(jax_to_numpy(v) for v in value)


@jax_to_numpy.register(_NoneType)
def _none_jax_to_numpy(value: None) -> None:
"""Passes through None values."""
return value


class JaxToNumpy(
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
gym.utils.RecordConstructorArgs,
):
class JaxToNumpy(ToArray):
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays.

Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -163,48 +69,4 @@ def __init__(self, env: gym.Env[ObsType, ActType]):
raise DependencyNotInstalled(
'Jax is not installed, run `pip install "gymnasium[jax]"`'
)
gym.utils.RecordConstructorArgs.__init__(self)
gym.Wrapper.__init__(self, env)

def step(
self, action: WrapperActType
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
"""Transforms the action to a jax array .

Args:
action: the action to perform as a numpy array

Returns:
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
"""
jax_action = numpy_to_jax(action)
obs, reward, terminated, truncated, info = self.env.step(jax_action)

return (
jax_to_numpy(obs),
float(reward),
bool(terminated),
bool(truncated),
jax_to_numpy(info),
)

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[WrapperObsType, dict[str, Any]]:
"""Resets the environment returning numpy-based observation and info.

Args:
seed: The seed for resetting the environment
options: The options for resetting the environment, these are converted to jax arrays.

Returns:
Numpy-based observations and info
"""
if options:
options = numpy_to_jax(options)

return jax_to_numpy(self.env.reset(seed=seed, options=options))

def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
super().__init__(env=env, env_xp=jnp, target_xp=np)
Member

We're doing some extremely cursed multiple inheritance here, so to maintain some sanity, I think it'd be better to name the parent class explicitly instead of using super(). The arguments are very specific, and super() can act unpredictably with multiple inheritance.

Contributor Author

See the comment about pickle. The multiple inheritance issue would go away completely. Also, all inheritance can be moved into the ToArray wrapper, so that all special wrappers only inherit from ToArray.
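
The concern about `super()` under multiple inheritance can be shown with a small sketch. The classes below are hypothetical stand-ins, not Gymnasium's actual classes. They only mimic the general shape of a wrapper base combined with an argument-recording mixin, under the assumption that neither parent calls `super().__init__()` itself.

```python
class RecordArgs:
    """Hypothetical mixin that records constructor arguments."""

    def __init__(self, **kwargs):
        self.saved_kwargs = dict(kwargs)


class Wrapper:
    """Hypothetical base wrapper that stores the wrapped environment."""

    def __init__(self, env):
        self.env = env


class ExplicitWrapper(Wrapper, RecordArgs):
    def __init__(self, env, target):
        # Each parent is named, so it is obvious which arguments go where.
        RecordArgs.__init__(self, target=target)
        Wrapper.__init__(self, env)


class SuperWrapper(Wrapper, RecordArgs):
    def __init__(self, env, target):
        # super() only reaches the next class in the MRO (Wrapper). Because
        # Wrapper.__init__ does not call super(), RecordArgs.__init__ never runs.
        super().__init__(env)


explicit = ExplicitWrapper(env="env", target="numpy")
implicit = SuperWrapper(env="env", target="numpy")

print(explicit.saved_kwargs)              # {'target': 'numpy'}
print(hasattr(implicit, "saved_kwargs"))  # False
```

If all of the inheritance moves into a single base wrapper, as proposed above, the specialised wrappers have exactly one parent and the difference between the two call styles largely disappears.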
