-
-
Notifications
You must be signed in to change notification settings - Fork 988
Add generic conversion wrapper between Array API compatible frameworks #1333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
a04df9a
82f71d5
1ec4e29
f801c7a
95f9988
f2f0b93
46d973e
d3ba730
5d6da0f
22da073
75f0c5c
06e6f08
50acf9f
a40f1fd
9556549
05d6d82
2b1dc7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,19 +3,16 @@ | |
from __future__ import annotations | ||
|
||
import functools | ||
import numbers | ||
from collections import abc | ||
from typing import Any, Iterable, Mapping, SupportsFloat | ||
|
||
import numpy as np | ||
|
||
import gymnasium as gym | ||
from gymnasium.core import ActType, ObsType, RenderFrame, WrapperActType, WrapperObsType | ||
from gymnasium.core import ActType, ObsType | ||
from gymnasium.error import DependencyNotInstalled | ||
from gymnasium.wrappers.to_array import ToArray, to_xp | ||
|
||
|
||
try: | ||
import jax | ||
import jax.numpy as jnp | ||
except ImportError: | ||
raise DependencyNotInstalled( | ||
|
@@ -24,104 +21,13 @@ | |
|
||
__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"] | ||
|
||
# The NoneType is not defined in Python 3.9. Remove when the minimal version is bumped to >=3.10 | ||
_NoneType = type(None) | ||
|
||
jax_to_numpy = functools.partial(to_xp, xp=np) | ||
|
||
@functools.singledispatch | ||
def numpy_to_jax(value: Any) -> Any: | ||
"""Converts a value to a Jax Array.""" | ||
raise Exception( | ||
f"No known conversion for Numpy type ({type(value)}) to Jax registered. Report as issue on github." | ||
) | ||
|
||
|
||
@numpy_to_jax.register(numbers.Number) | ||
def _number_to_jax( | ||
value: numbers.Number, | ||
) -> jax.Array: | ||
"""Converts a number (int, float, etc.) to a Jax Array.""" | ||
assert jnp is not None | ||
return jnp.array(value) | ||
|
||
|
||
@numpy_to_jax.register(np.ndarray) | ||
def _numpy_array_to_jax(value: np.ndarray) -> jax.Array: | ||
"""Converts a NumPy Array to a Jax Array with the same dtype (excluding float64 without being enabled).""" | ||
assert jnp is not None | ||
return jnp.array(value, dtype=value.dtype) | ||
|
||
|
||
@numpy_to_jax.register(abc.Mapping) | ||
def _mapping_numpy_to_jax(value: Mapping[str, Any]) -> Mapping[str, Any]: | ||
"""Converts a dictionary of numpy arrays to a mapping of Jax Array.""" | ||
return type(value)(**{k: numpy_to_jax(v) for k, v in value.items()}) | ||
|
||
|
||
@numpy_to_jax.register(abc.Iterable) | ||
def _iterable_numpy_to_jax( | ||
value: Iterable[np.ndarray | Any], | ||
) -> Iterable[jax.Array | Any]: | ||
"""Converts an Iterable from Numpy Arrays to an iterable of Jax Array.""" | ||
if hasattr(value, "_make"): | ||
# namedtuple - underline used to prevent potential name conflicts | ||
# noinspection PyProtectedMember | ||
return type(value)._make(numpy_to_jax(v) for v in value) | ||
else: | ||
return type(value)(numpy_to_jax(v) for v in value) | ||
|
||
|
||
@numpy_to_jax.register(_NoneType) | ||
def _none_numpy_to_jax(value: None) -> None: | ||
"""Passes through None values.""" | ||
return value | ||
|
||
|
||
@functools.singledispatch | ||
def jax_to_numpy(value: Any) -> Any: | ||
"""Converts a value to a numpy array.""" | ||
raise Exception( | ||
f"No known conversion for Jax type ({type(value)}) to NumPy registered. Report as issue on github." | ||
) | ||
|
||
|
||
@jax_to_numpy.register(jax.Array) | ||
def _devicearray_jax_to_numpy(value: jax.Array) -> np.ndarray: | ||
"""Converts a Jax Array to a numpy array.""" | ||
return np.array(value) | ||
|
||
numpy_to_jax = functools.partial(to_xp, xp=jnp) | ||
|
||
@jax_to_numpy.register(abc.Mapping) | ||
def _mapping_jax_to_numpy( | ||
value: Mapping[str, jax.Array | Any], | ||
) -> Mapping[str, np.ndarray | Any]: | ||
"""Converts a dictionary of Jax Array to a mapping of numpy arrays.""" | ||
return type(value)(**{k: jax_to_numpy(v) for k, v in value.items()}) | ||
|
||
|
||
@jax_to_numpy.register(abc.Iterable) | ||
def _iterable_jax_to_numpy( | ||
value: Iterable[np.ndarray | Any], | ||
) -> Iterable[jax.Array | Any]: | ||
"""Converts an Iterable from Numpy arrays to an iterable of Jax Array.""" | ||
if hasattr(value, "_make"): | ||
# namedtuple - underline used to prevent potential name conflicts | ||
# noinspection PyProtectedMember | ||
return type(value)._make(jax_to_numpy(v) for v in value) | ||
else: | ||
return type(value)(jax_to_numpy(v) for v in value) | ||
|
||
|
||
@jax_to_numpy.register(_NoneType) | ||
def _none_jax_to_numpy(value: None) -> None: | ||
"""Passes through None values.""" | ||
return value | ||
|
||
|
||
class JaxToNumpy( | ||
gym.Wrapper[WrapperObsType, WrapperActType, ObsType, ActType], | ||
gym.utils.RecordConstructorArgs, | ||
): | ||
class JaxToNumpy(ToArray, gym.utils.RecordConstructorArgs, gym.utils.ezpickle.EzPickle): | ||
"""Wraps a Jax-based environment such that it can be interacted with NumPy arrays. | ||
|
||
Actions must be provided as numpy arrays and observations will be returned as numpy arrays. | ||
|
@@ -164,47 +70,5 @@ def __init__(self, env: gym.Env[ObsType, ActType]): | |
'Jax is not installed, run `pip install "gymnasium[jax]"`' | ||
) | ||
gym.utils.RecordConstructorArgs.__init__(self) | ||
gym.Wrapper.__init__(self, env) | ||
|
||
def step( | ||
self, action: WrapperActType | ||
) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: | ||
"""Transforms the action to a jax array . | ||
|
||
Args: | ||
action: the action to perform as a numpy array | ||
|
||
Returns: | ||
A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info. | ||
""" | ||
jax_action = numpy_to_jax(action) | ||
obs, reward, terminated, truncated, info = self.env.step(jax_action) | ||
|
||
return ( | ||
jax_to_numpy(obs), | ||
float(reward), | ||
bool(terminated), | ||
bool(truncated), | ||
jax_to_numpy(info), | ||
) | ||
|
||
def reset( | ||
self, *, seed: int | None = None, options: dict[str, Any] | None = None | ||
) -> tuple[WrapperObsType, dict[str, Any]]: | ||
"""Resets the environment returning numpy-based observation and info. | ||
|
||
Args: | ||
seed: The seed for resetting the environment | ||
options: The options for resetting the environment, these are converted to jax arrays. | ||
|
||
Returns: | ||
Numpy-based observations and info | ||
""" | ||
if options: | ||
options = numpy_to_jax(options) | ||
|
||
return jax_to_numpy(self.env.reset(seed=seed, options=options)) | ||
|
||
def render(self) -> RenderFrame | list[RenderFrame] | None: | ||
"""Returns the rendered frames as a numpy array.""" | ||
return jax_to_numpy(self.env.render()) | ||
gym.utils.ezpickle.EzPickle.__init__(self, env) | ||
super().__init__(env=env, env_xp=jnp, target_xp=np) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're doing an extremely cursed multiple inheritance here, so to maintain some sanity here - I think it'd be better here to specify the parent class instead of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See the comment about pickle. The multiple inheritance issue would go away completely. Also, all inheritance can be moved into the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need EzPickle here? Plus does this actually work with arbitrary envs?
If I remember right, the idea behind EzPickle is that you can "fake pickle" an environment, so that later you can reconstruct it in its original form, without necessarily preserving the actual state of the environment. This wrapper doesn't have any arguments, so I'd guess this isn't necessary. But please do check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is, but your comment made me realize we need even more than that. We cannot pickle the
ToArray
wrapper because it contains references to two modules,self._env_xp
andself._target_xp
. These will prevent pickle from working, because you cannot pickle whole modules. What's more, using them as arguments will also preventEzPickle
from working correctly. I see two ways around that:__setstate__
and__getstate__
functions (should be fairly easy)The reason I dislike the first option is that using the modules as input will a) force people to only use modules they have actually installed and b) prevent a magic conversion from strings to modules.
What's your opinion on that?