Add generic conversion wrapper between Array API compatible frameworks #1333
Conversation
…. Fix conversions. Fix all tests
@pseudo-rnd-thoughts Any thoughts on how we might resolve the issue with Python 3.9? I think this wrapper makes a lot of sense to provide a consistent way of switching between array frameworks, but obviously the CI cannot fail for Python 3.9. As far as I can see, the gymnasium team would need to make a call on whether or not it's okay to drop 3.9 support for jax and torch. I'm not sure if lower Python versions are required for compatibility, given that the addition of torch and jax to gymnasium is fairly new.
@amacati Thanks for the PR, sorry I must have forgotten to review it. Yes, this looks great.
No worries, I just opened this. The dependency issue is not immediately obvious. The array API gets extended in yearly editions by the consortium once features are stabilized and deemed essential for all frameworks. The wrapper makes use of array API features that have been stabilized in the 2023 edition. However, numpy only supports those from numpy 2.1 onwards (see the release notes, first highlight at the very top). Since I deem numpy the core array API framework for gymnasium, I would not want to ship conversion wrappers that are broken for numpy. It should at least be communicated clearly, and ideally we could specify it in the dependencies. In short: yes, we can allow numpy<2.1, but this would render the conversion wrappers unusable for those numpy versions.
As a side note, is there any way to modify the observation and action spaces so that they automatically return arrays in the correct framework? This is something I have encountered quite frequently. It's rather unintuitive that the spaces return numpy arrays for, e.g., a jax environment. After all, a "true" observation would be a jax array. This might warrant its own PR, it just came to my mind while writing the wrapper.
Ahh, ok. With the v1.1 release, I think we can then move to supporting Python 3.10+.
Not that it automatically returns arrays in the correct framework, but we added a feature for this in gymnasium/utils/env_checker.py (line 388 in ba0fa45).
Oh no, I didn't mean the wrapper around the environment. I meant the actual spaces, i.e.

```python
type(env.observation_space.sample())               # >>> torch.Tensor
torch.tensor([1.0, 2.0]) in env.observation_space  # >>> True
```

Regarding gymnasium 1.1: That means we are waiting for the next release, leave this PR open until then, and merge it with a bumped Python requirement?
Ohh I see what you mean. In my head, I would want to modify the spaces themselves.
I meant that, with Gymnasium v1.1 fixing bugs in Gymnasium v1.0, I'm happy for v1.2 (the next release) to reduce the supported Python versions from 3.8+ to 3.10+.
Yes, makes perfect sense. Though we'd also have to redefine a bunch of other stuff to make that work. I will keep this in mind and wait for this PR to land before working on it.
Alright, looking forward to the release of 1.1 then to pick this up again!
Ahh yes, this is why I hadn't done this previously, because you would need to implement it for every space. Also, we already cut v1.1 two weeks ago, but with minimal publicity (https://github.com/Farama-Foundation/Gymnasium/releases/tag/v1.1.0).
If 1.1 is out, should I update the Python requirement for the whole package to 3.10 to move forward? Or do you want to do this in a separate PR?
Update to 3.10 here.
I have increased the version requirements, addressed a bug in the vectorized version of ToArray, and added tests for the vector wrapper. Should be ready for a first review.
Looks good overall. I like that this simplifies implementations.
My main question is: if a particular implementation, e.g. jax, wants to implement a new conversion function, is that possible?
Is there any difference between the NumpyToTorch and JaxToTorch wrappers? Could we just have ToNumpy, ToJax, and ToTorch wrappers?
Also, could you add this to the documentation (https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/misc_wrappers.md and https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/table.md)?
At the moment: no. However, we can easily integrate that in the future if we want to. I'm currently working on a similar PR for scipy where we already want to support multiple implementations. The solution there is to have a backend registry. We can then either register custom backends on our side, or allow users to define their own, specialized versions of conversions. To give you a rough idea, this would look approximately like so:

```python
# ...
backend = {}


@to_xp.register(abc.Iterable)
def _iterable_to_xp(
    value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
    if is_array_api_obj(value):
        return backend.get(xp, _array_api_to_xp)(value, xp, device)
    ...


def _array_api_to_xp(
    value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
    try:
        x = xp.from_dlpack(value)
        return to_device(x, device) if device is not None else x
    except (RuntimeError, BufferError):
        value_namespace = array_namespace(value)
        value_copy = value_namespace.asarray(value, copy=True)
        return xp.asarray(value_copy, device=device)


def special_to_jax_conversion(value, xp, device):
    return jax.numpy.from_dlpack(value, device=device)


backend[jax.numpy] = special_to_jax_conversion
# ...
```

The advantage of this approach is that you can let other packages define efficient conversions by exposing some official registration mechanism.
No, there should not be any difference. I thought that you wanted to keep them around for backward compatibility. Also, the specialized wrappers essentially reduce to partial applications of ToArray:

```python
from functools import partial

import numpy as np

from gymnasium.wrappers.to_array import ToArray

ToNumpy = partial(ToArray, target_xp=np)
```
- Add ToArray wrapper to docs
- Add array-api-extra for reliable testing
- Fix smaller mistakes in the docs
- Add NumPy version check for Array API wrappers
I have addressed most of your comments. In addition, I noticed that some tests failed with jax on systems with a GPU because the test comparison function for Array API objects did not cover some edge cases. Therefore, I have added array-api-extra as a test dependency. However, since this one is exclusively used for testing and is part of the Array API projects which we depend on anyway, I feel like this is worth including. With it, all tests also pass on systems with GPUs.
@amacati This looks good to me, but I want @RedTachyon to look at the PR before merging as they were one of the authors of the original data conversion wrappers. Hopefully they will look at it this weekend or next week.
@amacati I might be wrong on this, but I realised that an important aspect of the NumpyToTorch wrapper is not just that the wrapper returns torch observations etc., but that you can pass torch actions which are converted to NumPy for the environment. I don't believe the implementation does that. This is why the wrappers are called NumpyToTorch etc., not just ToTorch.
@pseudo-rnd-thoughts This is why there is an env_xp argument in addition to target_xp. It definitely does! The old tests for this still pass (see gymnasium/wrappers/to_array.py, line 222 in 05d6d82).
On the naming: I wanted to automate the environment framework detection, but there is no reliable way to get a sample of the environment's arrays without resetting the env. Actually, this was also a motivation for having observation spaces that use the "correct" framework. That would allow us to use auto-detection through the spaces.
Hi, thanks for the PR, I really like the idea of this fairly generic array-to-array conversion. I left a bunch of comments, some important, some nitpicks.
As a caveat: I didn't actually run the new code, I'm trusting the tests on this. The review isn't super comprehensive, but hopefully it'll help catch some potential issues before they actually come up.
```diff
@@ -31,7 +31,7 @@ To install the base Gymnasium library, use `pip install gymnasium`

 This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.

-We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
+We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
```
Python 3.9 has some more time until EoL - I think I'd be fine equipping this feature with some explicit warning/exception that says it doesn't work on 3.9, but I'd be very careful removing support for this version just for this feature.
That would render the wrappers temporarily unusable for people who previously had access to them. What's your take on this, @pseudo-rnd-thoughts? Also, EOL is in October, so that's not too far off.
gymnasium/wrappers/jax_to_numpy.py (outdated)

```python
def render(self) -> RenderFrame | list[RenderFrame] | None:
    """Returns the rendered frames as a numpy array."""
    return jax_to_numpy(self.env.render())

gym.utils.ezpickle.EzPickle.__init__(self, env)
```
Do we need EzPickle here? Plus does this actually work with arbitrary envs?
If I remember right, the idea behind EzPickle is that you can "fake pickle" an environment, so that later you can reconstruct it in its original form, without necessarily preserving the actual state of the environment. This wrapper doesn't have any arguments, so I'd guess this isn't necessary. But please do check
It is, but your comment made me realize we need even more than that. We cannot pickle the ToArray wrapper because it contains references to two modules, self._env_xp and self._target_xp. These will prevent pickle from working, because you cannot pickle whole modules. What's more, using them as arguments will also prevent EzPickle from working correctly. I see two ways around that:

- Using strings instead of modules to create the wrapper (not a fan)
- Creating custom __setstate__ and __getstate__ functions (should be fairly easy)

The reason I dislike the first option is that using the modules as input a) forces people to only use modules they have actually installed and b) avoids the need for a magic conversion from strings to modules.

What's your opinion on that?
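A minimal sketch of option 2, assuming the wrapper stores its frameworks in _env_xp and _target_xp as discussed above (illustrative only, not the PR's implementation): pickle the module names instead of the module objects and re-import them on load.

```python
import importlib


class ToArray:  # simplified stand-in for the real wrapper
    def __init__(self, env, env_xp, target_xp):
        self.env = env
        self._env_xp = env_xp
        self._target_xp = target_xp

    def __getstate__(self):
        state = self.__dict__.copy()
        # Modules cannot be pickled, so store their import names instead.
        state["_env_xp"] = self._env_xp.__name__
        state["_target_xp"] = self._target_xp.__name__
        return state

    def __setstate__(self, state):
        # Re-import the framework modules from their names on unpickling.
        state["_env_xp"] = importlib.import_module(state["_env_xp"])
        state["_target_xp"] = importlib.import_module(state["_target_xp"])
        self.__dict__.update(state)
```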
"""Returns the rendered frames as a numpy array.""" | ||
return jax_to_numpy(self.env.render()) | ||
gym.utils.ezpickle.EzPickle.__init__(self, env) | ||
super().__init__(env=env, env_xp=jnp, target_xp=np) |
We're doing some extremely cursed multiple inheritance here, so to maintain some sanity I think it'd be better to specify the parent class instead of super(). Since the arguments are very specific, super() can act unpredictably with multiple inheritance.
See the comment about pickle. The multiple inheritance issue would go away completely. Also, all inheritance can be moved into the ToArray wrapper, so that all special wrappers only inherit from ToArray.
gymnasium/wrappers/jax_to_torch.py (outdated)

```python
gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.utils.ezpickle.EzPickle.__init__(self, env, device)
```
See my comments in jax_to_numpy.
```python
# TODO: Device was part of the public API, but should be removed in favor of _env_device and _target_device
```
Is this meant to be kept for a future PR?
Ah, good catch. This was meant as a comment for reviewers. JaxToTorch previously had a device attribute, which would be removed with the new ToArray wrapper in favor of _env_device and _target_device. These are private however, whereas device was public. Removing it could break existing code. The question was if you are willing to make that breaking change.
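One way to soften that breaking change would be a deprecation shim; a hypothetical sketch, assuming the _target_device attribute from the discussion (not code from the PR):

```python
import warnings


class JaxToTorch:  # minimal stand-in; the real wrapper would build on ToArray
    def __init__(self, env, device=None):
        self.env = env
        self._target_device = device

    @property
    def device(self):
        """Deprecated public alias for the wrapper's (private) target device."""
        warnings.warn(
            "`device` is deprecated; rely on the wrapper's device arguments instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._target_device
```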
```python
return jax_to_numpy(self.env.reset(seed=seed, options=options))
gym.utils.RecordConstructorArgs.__init__(self)
gym.utils.ezpickle.EzPickle.__init__(self, env)
```
See jax_to_numpy.py
```
@@ -45,48 +48,8 @@ def __init__(self, env: VectorEnv, device: Device | None = None):
            env: The NumPy-based vector environment to wrap
            device: The device the torch Tensors should be moved to
        """
        super().__init__(env)
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.utils.ezpickle.EzPickle.__init__(self, env, device)
```
See jax_to_numpy.py. (Side note: I just noticed that wrappers pointing to torch do actually have the extra device argument; still, I'm not sure it's the right pattern for conversion wrappers.)
```
        Returns:
            xp-based observations and info
        """
        if options:
```
See the non-vectorized to_array.py
```
        Returns:
            xp-based observations and info
        """
        if options:
```
This is where things might get a bit dicey. Options are meant to be relatively free-form, and I don't think we should force any specific structure on this object. In particular, values might be strings, or any other exotic objects. It's not obvious to me whether they should be automatically converted. The same thing applies to info dicts. Does the current to_xp function handle arbitrary objects that might not be easily convertible?

I'd consider adding an extra parameter to this wrapper (and the vector version), defining the behavior for options/infos/anything else that might be questionable. Else this might cause a lot of weird interactions.

As an example: an option is an int which is then converted to a string, or maybe even directly used as a key in a dictionary. Automatically converting options to arrays would have some hard to predict consequences.

(Side note: I realize that this issue was already present, but we might as well fix it or think about it if we're revamping the whole thing anyway.)
It does not. Currently, anything that can't be converted raises an error. There is also no way to specify that e.g. a numpy array in the info dict should not be converted. The list of objects we can convert is limited to:
- Numbers
- Mappings
- Iterables (-> Arrays)
- None
So even a string should raise an error. We could of course also return objects that are not supported as they are. Alternatively, one could have convert_info and convert_options arguments, but I am also in favor of keeping all init args limited and simple.
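For illustration, a hedged sketch of what the pass-through alternative could look like at the dispatch fallback, reusing the to_xp name for readability (hypothetical code, not the PR's implementation):

```python
from functools import singledispatch
from types import ModuleType
from typing import Any


@singledispatch
def to_xp(value: Any, xp: ModuleType, device: Any = None) -> Any:
    # Current behavior: unsupported types (e.g. strings) raise.
    # The pass-through alternative would instead `return value` unchanged.
    raise TypeError(f"Cannot convert {type(value).__name__} to {xp.__name__}")


@to_xp.register(int)
@to_xp.register(float)
def _number_to_xp(value, xp, device=None):
    # Numbers become 0-d arrays of the target framework.
    return xp.asarray(value, device=device)


@to_xp.register(dict)
def _mapping_to_xp(value, xp, device=None):
    # Info dicts and options would be converted recursively.
    return {k: to_xp(v, xp, device) for k, v in value.items()}


@to_xp.register(type(None))
def _none_to_xp(value, xp, device=None):
    return None
```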
```diff
@@ -8,16 +8,14 @@ build-backend = "setuptools.build_meta"
 name = "gymnasium"
 description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
 readme = "README.md"
-requires-python = ">= 3.8"
+requires-python = ">= 3.10"
```
See my comments in another file about bumping the version requirements
The latest changes address some of the concerns with pickle, though I'm not 100% satisfied. Any thoughts on what a good solution would look like, @RedTachyon?
How willing are you to make breaking changes, @RedTachyon? If we make the arguments to
Description
Motivation
Recently, environments implemented in frameworks such as torch and jax have gained importance as they provide a convenient way to run parallelized environments on GPUs and compile parts of the logic. To enable the mixing of environments written in one framework and training in another framework, gymnasium provides wrappers that automatically convert from one framework to the other and back. However, only a few combinations are currently implemented, namely JaxToNumpy, JaxToTorch and NumpyToTorch. Other combinations, e.g. an environment implemented in torch and converting to jax arrays, are missing.

Proposal
This PR proposes to add a generic ToArray wrapper from one array framework to the other. It leverages the Python array API standard (see here), which defines common operations that any framework has to implement in order to be compatible. Frameworks that are currently not fully compatible are wrapped in a compatibility layer (see array-api-compat). With this design, the wrapper can handle any combination of supported libraries. Notably, this includes previously supported frameworks such as numpy and torch, as well as new ones such as cupy.

Advantages
One major advantage is that gymnasium can support conversions between arbitrary array API compatible frameworks without additional efforts on our side. As soon as a library is compatible with the standard, it automatically works with the new wrapper.
Furthermore, the new to_xp conversion function to an arbitrary, array API compatible framework xp only has to be defined once. This significantly reduces code duplication (compare the near-identical conversions in jax_to_torch.py and numpy_to_torch.py). This change is fully backwards compatible. We can retain the old wrappers as special cases of the ToArray wrapper, and keep the conversion functions. As an example, numpy_to_torch can now be written as a special case of the generic to_xp function while maintaining backwards compatibility.

The new wrapper also allows more control over the array devices. Up to now, it was e.g. not possible to specify a device for a jax environment in JaxToTorch.
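As a rough sketch of that special case, assuming a to_xp(value, xp, device) signature living in gymnasium.wrappers.to_array (the exact code in the PR may differ):

```python
import torch

from gymnasium.wrappers.to_array import to_xp  # assumed location of the generic converter


def numpy_to_torch(value, device=None):
    """Convert (possibly nested) NumPy data to torch tensors via the generic to_xp."""
    return to_xp(value, xp=torch, device=device)
```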
Disadvantages
While many libraries are actively working towards supporting the array API, some have not achieved full compatibility yet, e.g. torch. The array API consortium provides a compatibility layer for these cases. In order to support the core frameworks, this introduces a new dependency on array-api-compat.

In addition, the wrapper requires frameworks to support at least the 2023 version of the standard. NumPy added support for this version in numpy 2.1. Therefore, installing gymnasium with its torch or jax dependencies will raise the required numpy version to 2.1.

Note: NumPy 2.1 only supports Python>=3.10! Consequently, this also raises the Python requirement for gymnasium[jax,torch]!

Example
Convert from a torch environment to jax using the ToArray wrapper.
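A minimal sketch of such a conversion, assuming the ToArray(env, env_xp=..., target_xp=...) interface from this PR; to keep the snippet self-contained, a NumPy-based environment is converted to jax here, and a torch-based environment would be wrapped the same way with env_xp=torch (the actual example may differ):

```python
import gymnasium as gym
import jax.numpy as jnp
import numpy as np

from gymnasium.wrappers.to_array import ToArray  # import path used elsewhere in this PR

# CartPole is NumPy-based; a torch-based environment would pass env_xp=torch instead.
env = gym.make("CartPole-v1")
env = ToArray(env, env_xp=np, target_xp=jnp)

obs, info = env.reset(seed=0)   # obs is now a jax array
action = jnp.asarray(1)         # actions can be passed as jax arrays
obs, reward, terminated, truncated, info = env.step(action)
```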
Type of change

Dependency changes
As explained in the previous sections, this PR requires the following version updates:
- numpy>=2.1 for gymnasium[torch,jax,array-api]
- python>=3.10, caused by numpy>=2.1
- array-api-compat for gymnasium[torch,jax,array-api]
Blockers
The current PR will fail the test suite on Python 3.9. If we are fine with supporting torch, jax etc. for Python 3.10 and onward only, this PR can move ahead. Otherwise, this PR is blocked until we officially drop support for 3.9.

Checklist:
- pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)