Add generic conversion wrapper between Array API compatible frameworks #1333

Open
wants to merge 17 commits into
base: main
Changes from 14 commits
2 changes: 1 addition & 1 deletion .github/workflows/docs-build-dev.yml
@@ -19,7 +19,7 @@ jobs:

       - uses: actions/setup-python@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'

       - name: Install dependencies
         run: pip install -r docs/requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/docs-build-release.yml
@@ -19,7 +19,7 @@ jobs:

       - uses: actions/setup-python@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'

       - name: Install dependencies
         run: pip install -r docs/requirements.txt
2 changes: 1 addition & 1 deletion .github/workflows/docs-manual-build.yml
@@ -33,7 +33,7 @@ jobs:

       - uses: actions/setup-python@v5
         with:
-          python-version: '3.9'
+          python-version: '3.10'

       - name: Install dependencies
         run: pip install -r docs/requirements.txt
5 changes: 2 additions & 3 deletions .github/workflows/run-pytest.yml
@@ -10,8 +10,8 @@ jobs:
     strategy:
       fail-fast: true
       matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12']
-        numpy-version: ['>=1.21,<2.0', '>=2.0']
+        python-version: ['3.10', '3.11', '3.12']
+        numpy-version: ['>=1.21,<2.0', '>=2.1']
     steps:
       - uses: actions/checkout@v4
       - run: |
@@ -22,7 +22,6 @@ jobs:
       - name: Run tests
         run: docker run gymnasium-all-docker pytest tests/*
       - name: Run doctests
-        if: ${{ matrix.python-version != '3.8' }}
         run: docker run gymnasium-all-docker pytest --doctest-modules gymnasium/

   build-necessary:
150 changes: 7 additions & 143 deletions gymnasium/wrappers/jax_to_numpy.py
@@ -3,19 +3,16 @@
 from __future__ import annotations

 import functools
-import numbers
-from collections import abc
-from typing import Any, Iterable, Mapping, SupportsFloat

 import numpy as np

 import gymnasium as gym
-from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType
+from gymnasium.core import ActType, ObsType
 from gymnasium.error import DependencyNotInstalled
+from gymnasium.wrappers.to_array import ToArray, to_xp


 try:
-    import jax
     import jax.numpy as jnp
 except ImportError:
     raise DependencyNotInstalled(
@@ -24,104 +21,13 @@

 __all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]

-# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10
-_NoneType = type(None)

+jax_to_numpy = functools.partial(to_xp, xp=np)

-@functools.singledispatch
-def numpy_to_jax(value: Any) -> Any:
-    """Converts a value to a Jax Array."""
-    raise Exception(
-        f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github."
-    )
-
-
-@numpy_to_jax.register(numbers.Number)
-def _number_to_jax(
-    value: numbers.Number,
-) -> jax.Array:
-    """Converts a number (int, float, etc.) to a Jax Array."""
-    assert jnp is not None
-    return jnp.array(value)
-
-
-@numpy_to_jax.register(np.ndarray)
-def _numpy_array_to_jax(value: np.ndarray) -> jax.Array:
-    """Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled)."""
-    assert jnp is not None
-    return jnp.array(value, dtype=value.dtype)
-
-
-@numpy_to_jax.register(abc.Mapping)
-def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]:
-    """Converts a dictionary of numpy arrays to a mapping of Jax Array."""
-    return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()})
-
-
-@numpy_to_jax.register(abc.Iterable)
-def _iterable_numpy_to_jax(
-    value: Iterable[np.ndarray | Any],
-) -> Iterable[jax.Array | Any]:
-    """Converts an Iterable from Numpy Arrays to an iterable of Jax Array."""
-    if hasattr(value, "_make"):
-        # namedtuple - underline used to prevent potential name conflicts
-        # noinspection PyProtectedMember
-        return type(value)._make(numpy_to_jax(v) for v in value)
-    else:
-        return type(value)(numpy_to_jax(v) for v in value)
-
-
-@numpy_to_jax.register(_NoneType)
-def _none_numpy_to_jax(value: None) -> None:
-    """Passes through None values."""
-    return value
-
-
-@functools.singledispatch
-def jax_to_numpy(value: Any) -> Any:
-    """Converts a value to a numpy array."""
-    raise Exception(
-        f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github."
-    )
-
-
-@jax_to_numpy.register(jax.Array)
-def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray:
-    """Converts a Jax Array to a numpy array."""
-    return np.array(value)
-
+numpy_to_jax = functools.partial(to_xp, xp=jnp)

-@jax_to_numpy.register(abc.Mapping)
-def _mapping_jax_to_numpy(
-    value: Mapping[str, jax.Array | Any],
-) -> Mapping[str, np.ndarray | Any]:
-    """Converts a dictionary of Jax Array to a mapping of numpy arrays."""
-    return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()})
-
-
-@jax_to_numpy.register(abc.Iterable)
-def _iterable_jax_to_numpy(
-    value: Iterable[np.ndarray | Any],
-) -> Iterable[jax.Array | Any]:
-    """Converts an Iterable from Numpy arrays to an iterable of Jax Array."""
-    if hasattr(value, "_make"):
-        # namedtuple - underline used to prevent potential name conflicts
-        # noinspection PyProtectedMember
-        return type(value)._make(jax_to_numpy(v) for v in value)
-    else:
-        return type(value)(jax_to_numpy(v) for v in value)
-
-
-@jax_to_numpy.register(_NoneType)
-def _none_jax_to_numpy(value: None) -> None:
-    """Passes through None values."""
-    return value
-

-class JaxToNumpy(
-    gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType],
-    gym.utils.RecordConstructorArgs,
-):
+class JaxToNumpy(ToArray, gym.utils.RecordConstructorArgs, gym.utils.ezpickle.EzPickle):
     """Wraps a Jax-based environment such that it can be interacted with NumPy arrays.

     Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
@@ -164,47 +70,5 @@ def __init__(self, env: gym.Env[ObsType, ActType]):
                 'Jax is not installed, run `pip install "gymnasium[jax]"`'
             )
         gym.utils.RecordConstructorArgs.__init__(self)
-        gym.Wrapper.__init__(self, env)
-
-    def step(
-        self, action: WrapperActType
-    ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]:
-        """Transforms the action to a jax array .
-
-        Args:
-            action: the action to perform as a numpy array
-
-        Returns:
-            A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.
-        """
-        jax_action = numpy_to_jax(action)
-        obs, reward, terminated, truncated, info = self.env.step(jax_action)
-
-        return (
-            jax_to_numpy(obs),
-            float(reward),
-            bool(terminated),
-            bool(truncated),
-            jax_to_numpy(info),
-        )
-
-    def reset(
-        self, *, seed: int | None = None, options: dict[str, Any] | None = None
-    ) -> tuple[WrapperObsType, dict[str, Any]]:
-        """Resets the environment returning numpy-based observation and info.
-
-        Args:
-            seed: The seed for resetting the environment
-            options: The options for resetting the environment, these are converted to jax arrays.
-
-        Returns:
-            Numpy-based observations and info
-        """
-        if options:
-            options = numpy_to_jax(options)
-
-        return jax_to_numpy(self.env.reset(seed=seed, options=options))
-
-    def render(self) -> RenderFrame | list[RenderFrame] | None:
-        """Returns the rendered frames as a numpy array."""
-        return jax_to_numpy(self.env.render())
+        gym.utils.ezpickle.EzPickle.__init__(self, env)
Member

Do we need EzPickle here? Plus does this actually work with arbitrary envs?

If I remember right, the idea behind EzPickle is that you can "fake pickle" an environment, so that later you can reconstruct it in its original form without necessarily preserving the actual state of the environment. This wrapper doesn't have any arguments, so I'd guess this isn't necessary, but please do check.
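To make the "fake pickle" idea concrete, here is a minimal sketch of it using only the standard library. `RecordsCtorArgs` and `Counter` are hypothetical stand-ins written for this illustration, not Gymnasium classes: only the constructor arguments are pickled, and unpickling rebuilds a fresh object from them.

```python
import pickle


class RecordsCtorArgs:
    """Hypothetical stand-in: pickles constructor arguments instead of live state."""

    def __init__(self, *args, **kwargs):
        self._ctor_args = args
        self._ctor_kwargs = kwargs

    def __getstate__(self):
        # Only the recorded constructor arguments go into the pickle.
        return {"args": self._ctor_args, "kwargs": self._ctor_kwargs}

    def __setstate__(self, state):
        # Rebuild a fresh instance from the arguments and adopt its attributes.
        fresh = type(self)(*state["args"], **state["kwargs"])
        self.__dict__.update(fresh.__dict__)


class Counter(RecordsCtorArgs):
    """Toy object with mutable state, standing in for an environment."""

    def __init__(self, start=0):
        RecordsCtorArgs.__init__(self, start=start)
        self.value = start

    def step(self):
        self.value += 1
        return self.value


counter = Counter(start=5)
counter.step()  # live state moves on to 6

restored = pickle.loads(pickle.dumps(counter))
print(counter.value, restored.value)
```

The restored object starts again from its constructor arguments, so the step taken before pickling is not carried over.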

Contributor Author

It is, but your comment made me realize we need even more than that. We cannot pickle the ToArray wrapper because it contains references to two modules, self._env_xp and self._target_xp. These will prevent pickle from working, because you cannot pickle whole modules. What's more, using them as arguments will also prevent EzPickle from working correctly. I see two ways around that:

  • Using strings instead of modules to create the wrapper (not a fan)
  • Creating custom __setstate__ and __getstate__ functions (should be fairly easy)

The reason I dislike the first option is that using the modules as input will a) force people to only use modules they have actually installed and b) prevent a magic conversion from strings to modules.

What's your opinion on that?
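As a rough sketch of the second option, the custom `__getstate__` and `__setstate__` could swap each module for its importable name and re-import it on load. `ModuleHolder` and its `_xp` attribute are hypothetical names for this illustration, and the approach assumes the stored namespace is importable via its `__name__`, which may not hold for every Array API namespace.

```python
import importlib
import math
import pickle


class ModuleHolder:
    """Hypothetical stand-in for a wrapper that keeps a module reference."""

    def __init__(self, xp):
        self._xp = xp  # a module object; pickle cannot serialize it directly

    def __getstate__(self):
        state = self.__dict__.copy()
        # Store the importable name instead of the module itself.
        state["_xp"] = self._xp.__name__
        return state

    def __setstate__(self, state):
        state = dict(state)
        # Assumption: the name recorded above can be imported again.
        state["_xp"] = importlib.import_module(state["_xp"])
        self.__dict__.update(state)


# Pickling a module directly fails with a TypeError.
try:
    pickle.dumps(math)
    module_pickle_failed = False
except TypeError:
    module_pickle_failed = True

holder = ModuleHolder(math)
restored = pickle.loads(pickle.dumps(holder))
print(module_pickle_failed, restored._xp is math)
```

The constructor still accepts module objects, so callers can only pass modules they have installed, and no string-to-module lookup appears in the public interface.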

+        super().__init__(env=env, env_xp=jnp, target_xp=np)
Member

We're doing an extremely cursed multiple inheritance here, so to maintain some sanity, I think it'd be better to call the parent class explicitly instead of using super(). The arguments are very specific, and super() can act unpredictably with multiple inheritance.
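A small sketch of the concern, with hypothetical classes rather than the PR's own: `super().__init__` goes to whichever class comes next in the MRO, so the same call can reach a different parent depending on base-class order, while an explicit parent call does not depend on that order.

```python
class ArrayWrapper:
    """Hypothetical parent that needs very specific arguments."""

    def __init__(self, env, env_xp, target_xp):
        self.env = env
        self.env_xp = env_xp
        self.target_xp = target_xp


class ArgRecorder:
    """Hypothetical mixin whose __init__ takes no arguments."""

    def __init__(self):
        self.recorded = True


class ExplicitInit(ArgRecorder, ArrayWrapper):
    def __init__(self, env):
        # Each parent is named, so base-class order does not matter.
        ArgRecorder.__init__(self)
        ArrayWrapper.__init__(self, env=env, env_xp="jax", target_xp="numpy")


class SuperInit(ArgRecorder, ArrayWrapper):
    def __init__(self, env):
        # super() follows the MRO, so this call lands in ArgRecorder.__init__.
        super().__init__(env=env, env_xp="jax", target_xp="numpy")


explicit = ExplicitInit(env="dummy-env")

try:
    SuperInit(env="dummy-env")
    super_error = None
except TypeError as exc:
    super_error = exc

print(explicit.env_xp, type(super_error).__name__)
```

With the base order used in the diff (`ToArray` first), `super()` would reach `ToArray`; the sketch only shows how a reordering of the bases changes that.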

Contributor Author

See the comment about pickle. The multiple inheritance issue would go away completely. Also, all inheritance can be moved into the ToArray wrapper, so that all special wrappers only inherit from ToArray.
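For context on the `functools.partial(to_xp, xp=...)` lines in the diff: the real `to_xp` lives in `gymnasium/wrappers/to_array.py`, which is not shown here. The sketch below is an assumption about its general shape, modelled on the singledispatch functions the diff removes, with a fake `xp` namespace so that it runs without NumPy or JAX.

```python
import functools
import numbers
from collections import abc
from types import SimpleNamespace


@functools.singledispatch
def to_xp(value, xp):
    """Hypothetical converter: dispatches on the type of `value`."""
    raise TypeError(f"No conversion registered for {type(value)}")


@to_xp.register(numbers.Number)
def _number_to_xp(value, xp):
    return xp.asarray(value)


@to_xp.register(abc.Mapping)
def _mapping_to_xp(value, xp):
    # Convert the values and keep the mapping type, as the removed code did.
    return type(value)(**{k: to_xp(v, xp) for k, v in value.items()})


@to_xp.register(abc.Iterable)
def _iterable_to_xp(value, xp):
    return type(value)(to_xp(v, xp) for v in value)


@to_xp.register(type(None))
def _none_to_xp(value, xp):
    return value


# Stand-in namespace: tags values instead of building real arrays.
fake_xp = SimpleNamespace(asarray=lambda v: ("fake", v))

# One generic function, specialised per target namespace with partial.
to_fake = functools.partial(to_xp, xp=fake_xp)

converted = to_fake({"a": 1, "b": [2, None]})
print(converted)
```

Binding `xp` with `functools.partial` means one set of registered handlers serves every target framework.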
