Add generic conversion wrapper between Array API compatible frameworks #1333
@pseudo-rnd-thoughts Any thoughts on how we might resolve the issue with Python 3.9? I think this wrapper makes a lot of sense to provide a consistent way of switching between array frameworks, but obviously the CI cannot fail for Python 3.9. As far as I can see, the gymnasium team would need to make a call on whether or not it's okay to drop 3.9 support for jax and torch. Not sure if lower python versions are required for compatibility, given that the addition of torch and jax in gymnasium are fairly new |
@amacati Thanks for the PR, sorry I must have forgotten to review it. Yes, this looks great. |
No worries, I just opened this. The dependency issue is not immediately obvious. The array API gets extended in yearly editions by the consortium once features are stabilized and deemed essential for all frameworks. The wrapper makes use of array API features that have been stabilized in the 2023 edition. However, numpy only supports those from numpy 2.1 onwards (see the release notes, first highlight at the very top). Since I deem numpy the core array API framework for gymnasium, I would not want to ship conversion wrappers that are broken for numpy. It should at least be communicated clearly, and ideally we could specify it in the dependencies. In short: yes, we can allow numpy<2.1, but this would render |
As a side note, is there any way to modify the observation and action spaces so that they automatically return arrays in the correct framework? This is something I have encountered quite frequently. It's rather unintuitive that the spaces return numpy arrays for, e.g., a jax environment. After all, a "true" observation would be a jax array. This might warrant its own PR, just came to my mind while writing the wrapper. |
Ahh, ok. With the release v1.1, then I think we can move to supporting python 3.10+
Not that it automatically returns arrays in the correct framework, but we added a feature for Gymnasium/gymnasium/utils/env_checker.py Line 388 in ba0fa45
Oh no, I didn't mean the wrapper around the environment. I meant the actual I.e.
```python
type(env.observation_space.sample())  # >>> torch.Tensor
torch.tensor([1.0, 2.0]) in env.observation_space  # >>> True
```
Regarding gymnasium 1.1: That means we are waiting for the next release, leave this PR open until then, and merge it with a
Ohh I see what you mean. In my head, I would want to modify the
I meant that, with Gymnasium v1.1 fixing bugs in Gymnasium v1.0, I'm happy for v1.2 (the next release) to raise the minimum Python version from 3.8+ to 3.10+.
Yes, makes perfect sense. Though we'd also have to redefine a bunch of other stuff to make I will keep this in mind and wait for this PR to land before working on it.
Alright, looking forward to the release of 1.1 then to pick this up again! |
Ahh yes, this is why I hadn't done this previously, because you would need to implement Also, we already cut v1.1 two weeks ago but with minimal publicity (https://github.com/Farama-Foundation/Gymnasium/releases/tag/v1.1.0) |
If 1.1 is out, should I update the Python requirement for the whole package to 3.10 to move forward? Or do you want to do this in a separate PR? |
Update to 3.10 here |
I have increased the version requirements, addressed a bug in the vectorized version of ToArray, and added tests for the vector wrapper. Should be ready for a first review. |
Looks good overall. I like that this simplifies implementations.
My main question: if a particular implementation, e.g. jax, wants to implement a new conversion function, is that possible?
Is there any difference between the `NumpyToTorch` and `JaxToTorch` wrappers? Could we just have `ToNumpy`, `ToJax`, and `ToTorch` wrappers?
Also could you add this to the documentation (https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/misc_wrappers.md and https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/table.md)
At the moment: No. However, we can easily integrate that in the future if we want to. I'm currently working on a similar PR for scipy where we already want to support multiple implementations. The solution there is to have a We can then either register custom backends on our side, or allow users to define their own, specialized versions of conversions. To give you a rough idea, this would look approximately like so
```python
# ...
backend = {}


@to_xp.register(abc.Iterable)
def _iterable_to_xp(
    value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
    if is_array_api_obj(value):
        return backend.get(xp, _array_api_to_xp)(value, xp, device)
    ...


def _array_api_to_xp(
    value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
    try:
        x = xp.from_dlpack(value)
        return to_device(x, device) if device is not None else x
    except (RuntimeError, BufferError):
        value_namespace = array_namespace(value)
        value_copy = value_namespace.asarray(value, copy=True)
        return xp.asarray(value_copy, device=device)


def special_to_jax_conversion(value, xp, device):
    return jax.numpy.from_dlpack(value, device=device)


backend[jax.numpy] = special_to_jax_conversion
# ...
```
The advantage of this approach is that you can let other packages define efficient conversions by exposing some official
No, there should not be any difference. I thought that you want to keep them around for backward compatibility. Also, the
```python
from functools import partial

import numpy as np

from gymnasium.wrappers.to_array import ToArray

ToNumpy = partial(ToArray, target_xp=np)
```
I have addressed most of your comments. In addition, I noticed that some tests failed with jax on systems with a GPU because the test comparison function for Array API objects did not cover some edge cases. Therefore, I have added However, since this one is exclusively used for testing and part of the Array API projects which we depend on anyways, I feel like this is worth including. With it, all tests pass also on systems with GPUs. |
@amacati This looks good to me, but I want @RedTachyon to look at the PR before merging as they were one of the authors of the original data conversion wrappers. Hopefully they will look at it this weekend or next week |
@amacati i might be wrong on this but i realised that an important aspect of the NumpyToTorch wrapper is not just that the wrapper returns torch observations etc, but that you can pass torch actions which are converted to NumPy for the environment. I don't believe the implementation does that. This is why the wrappers are called NumpyToTorch etc not just ToTorch |
@pseudo-rnd-thoughts This is why there is an
It definitely does! The old tests for Gymnasium/gymnasium/wrappers/to_array.py Line 222 in 05d6d82
Taking the I wanted to automate the environment framework detection, but there is no reliable way without resetting the env to get a sample of Actually, this was also a motivation for having observation spaces that use the "correct" framework. That would allow us to use auto-detection through |
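To make the two directions of conversion discussed above concrete, here is a minimal sketch. It is not the gymnasium implementation: `ToyEnv` and `TwoWayConversion` are hypothetical names, and plain Python lists and tuples stand in for the environment framework and the target framework. The point is only that actions are converted into the environment's framework before `step`, and results are converted back to the caller's framework afterwards.

```python
from typing import Any, Callable


class ToyEnv:
    """Toy environment whose 'framework' is plain Python lists (stand-in)."""

    def step(self, action: list) -> list:
        # The environment only understands its own array type.
        if not isinstance(action, list):
            raise TypeError("ToyEnv expects list actions")
        return [2 * a for a in action]


class TwoWayConversion:
    """Hypothetical wrapper: converts actions in and observations out."""

    def __init__(
        self,
        env: ToyEnv,
        to_env: Callable[[Any], Any],
        to_target: Callable[[Any], Any],
    ):
        self.env = env
        self._to_env = to_env  # target framework -> environment framework
        self._to_target = to_target  # environment framework -> target framework

    def step(self, action: Any) -> Any:
        env_action = self._to_env(action)
        obs = self.env.step(env_action)
        return self._to_target(obs)


# The caller works with tuples, the environment with lists.
wrapped = TwoWayConversion(ToyEnv(), to_env=list, to_target=tuple)
result = wrapped.step((1, 2))
```

In the real wrapper the two callables would be replaced by one generic conversion function parameterized by the environment namespace and the target namespace, which is why both namespaces have to be known when the wrapper is created.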
Hi, thanks for the PR, I really like the idea of this fairly generic array-to-array conversion. I left a bunch of comments, some important, some nitpicks.
As a caveat: I didn't actually run the new code, I'm trusting the tests on this. The review isn't super comprehensive, but hopefully it'll help catch some potential issues before they actually come up.
```diff
@@ -31,7 +31,7 @@ To install the base Gymnasium library, use `pip install gymnasium`
 
 This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.
 
-We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
+We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
```
Python 3.9 has some more time until EoL - I think I'd be fine equipping this feature with some explicit warning/exception that says it doesn't work on 3.9, but I'd be very careful removing support for this version just for this feature.
That would render the wrappers temporarily unusable for people that previously had access to them. What's your take on this @pseudo-rnd-thoughts ? Also, EOL is in October, so that's not too far off
I agree with @RedTachyon: removing 3.9 because of a very small and uncommonly used part of the project doesn't seem right. Could we raise a warning if the Python version is 3.8 when someone tries to use the wrappers?
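A minimal sketch of such a guard, assuming a hypothetical helper `warn_if_unsupported_python` that a conversion wrapper could call on construction. The helper name, the warning category, and the message are illustrative and not part of gymnasium; the 3.10 threshold is taken from the NumPy 2.1 requirement discussed in this thread.

```python
import sys
import warnings

# Minimum Python version implied by numpy>=2.1 (per the discussion above).
MIN_PYTHON = (3, 10)


def warn_if_unsupported_python(version_info=None) -> bool:
    """Warn if the interpreter is too old for the conversion wrappers.

    Hypothetical helper for illustration. Returns True if a warning was issued.
    """
    if version_info is None:
        version_info = sys.version_info
    current = tuple(version_info[:2])
    if current < MIN_PYTHON:
        warnings.warn(
            f"Array conversion wrappers need Python >= "
            f"{MIN_PYTHON[0]}.{MIN_PYTHON[1]}, but found "
            f"{current[0]}.{current[1]}.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False
```

The trade-off raised in the replies still applies: a warning keeps the package installable on older interpreters, but the wrappers themselves would remain unusable there.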
gymnasium/wrappers/jax_to_numpy.py
Outdated
def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
gym.utils.ezpickle.EzPickle.__init__(self, env)
Do we need EzPickle here? Plus does this actually work with arbitrary envs?
If I remember right, the idea behind EzPickle is that you can "fake pickle" an environment, so that later you can reconstruct it in its original form, without necessarily preserving the actual state of the environment. This wrapper doesn't have any arguments, so I'd guess this isn't necessary. But please do check
It is, but your comment made me realize we need even more than that. We cannot pickle the `ToArray` wrapper because it contains references to two modules, `self._env_xp` and `self._target_xp`. These will prevent pickle from working, because you cannot pickle whole modules. What's more, using them as arguments will also prevent `EzPickle` from working correctly. I see two ways around that:
- Using strings instead of modules to create the wrapper (not a fan)
- Creating custom `__setstate__` and `__getstate__` functions (should be fairly easy)
The reason I dislike the first option is that using the modules as input will a) force people to only use modules they have actually installed and b) prevent a magic conversion from strings to modules.
What's your opinion on that?
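The second option could look roughly like the sketch below. It is an illustration under assumptions, not the code from this PR: `ModuleHolder` is a hypothetical stand-in for the wrapper, and it assumes each stored namespace can be re-imported from its `__name__`.

```python
import importlib
from types import ModuleType


class ModuleHolder:
    """Stand-in for a wrapper that stores two array namespaces (hypothetical)."""

    def __init__(self, env_xp: ModuleType, target_xp: ModuleType):
        self._env_xp = env_xp
        self._target_xp = target_xp

    def __getstate__(self) -> dict:
        # Modules cannot be pickled, so store their import names instead.
        state = self.__dict__.copy()
        state["_env_xp"] = self._env_xp.__name__
        state["_target_xp"] = self._target_xp.__name__
        return state

    def __setstate__(self, state: dict) -> None:
        # Re-import the modules by name when the object is restored.
        state = dict(state)
        state["_env_xp"] = importlib.import_module(state["_env_xp"])
        state["_target_xp"] = importlib.import_module(state["_target_xp"])
        self.__dict__.update(state)
```

With these two methods in place, `pickle.loads(pickle.dumps(holder))` should give back an object whose attributes are the imported modules again. The constructor still takes modules, so users can only pass frameworks they have actually installed.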
gymnasium/wrappers/to_array.py
Outdated
Returns:
xp-based observations and info
"""
if options:
This is where things might get a bit dicey. Options are meant to be relatively free-form, and I don't think we should force any specific structure on this object. In particular, values might be strings, or any other exotic objects. It's not obvious to me whether they should be automatically converted. Same thing applies to info dicts. Does the current `to_xp` function handle arbitrary objects that might not be easily convertible?
I'd consider adding an extra parameter to this (and the vector version) wrapper, defining the behavior of options/infos/anything else that might be questionable. Else this might cause a lot of weird interactions.
As an example: an option is an int which is then converted to a string, or maybe even directly used as a key in a dictionary. Automatically converting options to arrays might have some hard-to-predict consequences.
(side note: I realize that this issue was already present, but might as well fix it/think about it if we're revamping the whole thing anyways)
It does not. Currently, anything that can't be converted raises an error. There is also no way to specify that e.g. a numpy array in the info dict should not be converted. The list of objects we can convert is limited to:
- Numbers
- Mappings
- Iterables (-> Arrays)
- None
So even a string should raise an error. We could ofc also return objects that are not supported as they are. Alternatively, one could have `convert_info` and `convert_options` arguments, but I am also in favor of keeping all init args limited and simple.
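The behavior described in this reply can be sketched with `functools.singledispatch`. This is only an illustration: `convert`, the `strict` flag, and `FakeArray` are hypothetical names, and `FakeArray` stands in for whatever `xp.asarray` would return. With `strict=True` unsupported objects raise, matching the current behavior; with `strict=False` they are returned as they are, matching the alternative mentioned above.

```python
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from functools import singledispatch
from numbers import Number
from typing import Any


@dataclass
class FakeArray:
    """Stand-in for a framework array such as the result of `xp.asarray`."""

    data: Any


def _unsupported(value: Any, strict: bool) -> Any:
    # Either fail loudly or hand the object back untouched.
    if strict:
        raise TypeError(f"Cannot convert object of type {type(value).__name__}")
    return value


@singledispatch
def convert(value: Any, strict: bool = True) -> Any:
    """Fallback for every type without a registered conversion."""
    return _unsupported(value, strict)


@convert.register(str)
@convert.register(bytes)
def _convert_text(value, strict: bool = True):
    # Strings are iterable, so exclude them explicitly from array conversion.
    return _unsupported(value, strict)


@convert.register(Number)
def _convert_number(value, strict: bool = True):
    return FakeArray(value)


@convert.register(Mapping)
def _convert_mapping(value, strict: bool = True):
    # Convert values recursively and keep the keys as they are.
    return {k: convert(v, strict) for k, v in value.items()}


@convert.register(Iterable)
def _convert_iterable(value, strict: bool = True):
    return FakeArray(list(value))


@convert.register(type(None))
def _convert_none(value, strict: bool = True):
    return None
```

A single flag like this keeps the init arguments small, but it cannot express "leave this one numpy array in the info dict alone"; that would need something closer to the separate `convert_info` and `convert_options` switches.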
The latest changes address some of the concerns with pickle, though I'm not 100% satisfied. Any thoughts on what a good solution would look like @RedTachyon ? |
How willing are you to make breaking changes @RedTachyon ? If we make the arguments to |
I haven't read all your comments, will do that tomorrow, but would |
Might be, yes. Or By the way, it seems like |
I like
I have rebased my PR to the latest changes in gymnasium and renamed the wrapper and conversion function. The remaining open questions, including concerns raised from @RedTachyon's review, are:
Pickling issues
Pickling is now implemented via a custom
I now fixed all the remaining tests, didn't see that parts of the pipeline were failing. What do you think @pseudo-rnd-thoughts @RedTachyon ? |
Apologies @amacati for taking way longer than expected to review.
Overall looks good and would be happy to merge with the suggested changes.
Thanks again for your hard work on the PR
… inheritance. Add tests for specialized vector ArrayConversion wrappers.
Regarding the 3.10 vs 3.9 issue: The
The wealth of issues this causes and the amount of workarounds needed to make this work definitely feel wrong, and I don't think it is smart to introduce this functionality while still supporting 3.9. |
The last commit simplifies the logic of One thing to be aware of here is that this requires us to import modules dynamically based on their names. As with all things that can be pickled, this is a potential security concern. However, that should be considered out of scope since the use of pickle is assumed anyways, and gymnasium is not intended to be used outside of trusted execution environments. |
Been eagerly watching this PR, it seems like a useful step towards making Gymnasium more library-agnostic. To suggest an idea for the 3.9 vs 3.10 issue, how about creating a short-lived Also, if you haven't seen it already, the guidance from SPEC 0 might be useful for determining which versions of dependencies should be supported. That outlines a fairly aggressive deprecation schedule, but it seems most users accept that they might need to use a slightly out-of-date version of a library if they require a significantly older Python version. |
Thanks for your interest @Jammf. |
I think at this point that is a call the gymnasium maintainers have to make. Either choice is fine really, in the end it's a matter of preference on the update policy. Of course I would like to see this merged as soon as possible since we are relying on some of the fixes for our research projects and also use it in classes that we teach. So just in case that wasn't clear, I'd be happy to see the bump to 3.10 |
Thinking about it, and given the other PRs that we are going to have to make (moving MuJoCo-Py to Gymnasium-Robotics and adding Python 3.13 support), I think we can add this now and accept the pain of dropping Python 3.8 and 3.9. My plan is to release Gymnasium 1.2 by the end of the month.
Thanks for your patience and hard work on the PR @amacati
Description
Motivation
Recently, environments implemented in frameworks such as `torch` and `jax` have gained importance as they provide a convenient way to run parallelized environments on GPUs and compile parts of the logic. To enable the mixing of environments written in one framework and training in another framework, `gymnasium` provides wrappers that automatically convert from one framework to the other and back. However, only a few combinations are currently implemented, namely JaxToNumpy, JaxToTorch and NumpyToTorch. Other combinations, e.g. an environment implemented in `torch` and converting to `jax` arrays, are missing.
Proposal
This PR proposes to add a generic `ToArray` wrapper from one array framework to the other. It leverages the Python array API standard (see here), which defines common operations that any framework has to implement in order to be compatible. Frameworks that are currently not fully compatible are wrapped in a compatibility layer (see array-api-compat). With this design, the wrapper can handle any combination of supported libraries. Notably, this includes previously supported frameworks such as `numpy` and `torch`, as well as new ones such as `cupy`.
Advantages
One major advantage is that gymnasium can support conversions between arbitrary array API compatible frameworks without additional efforts on our side. As soon as a library is compatible with the standard, it automatically works with the new wrapper.
Furthermore, the new `to_xp` conversion function to an arbitrary, array API compatible framework `xp` only has to be defined once. This significantly reduces code duplication (compare the near-identical conversions in jax_to_torch.py and numpy_to_torch.py). This change is fully backwards compatible. We can retain the old wrappers as special cases of the `ToArray` wrapper, and keep the conversion functions.
As an example, `numpy_to_torch` can now be written as a thin wrapper, which reduces it to a special case of the generic `to_xp` function while maintaining backwards compatibility.
The new wrapper also allows more control over the array devices. Up to now, it was e.g. not possible to specify a device for a jax environment in `JaxToTorch`.
Disadvantages
While many libraries are actively working towards supporting the array API, some have not achieved full compatibility yet, e.g. `torch`. The array API consortium provides a compatibility layer for these cases. In order to support the core frameworks, this introduces a new dependency on `array-api-compat`.
In addition, the wrapper requires frameworks to support at least the 2023 version of the standard. NumPy added support for this version in numpy 2.1. Therefore, installing `gymnasium` with its torch or jax dependencies will raise the required numpy version to 2.1.
Note: NumPy 2.1 only supports Python>=3.10! Consequently, this also raises the Python requirement for `gymnasium[jax,torch]`!
Example
Convert from a `torch` environment to `jax` using the `ToArray` wrapper.
Type of change
Dependency changes
As explained in the previous sections, this PR requires the following version updates:
- `numpy>=2.1` for `gymnasium[torch,jax,array-api]`
- `python>=3.10` caused by `numpy>=2.1`
- `array-api-compat` for `gymnasium[torch,jax,array-api]`
Blockers
The current PR will fail the test suite on Python 3.9. If we are fine with supporting `torch`, `jax` etc. for python 3.10 and onward only, this PR can move ahead. Otherwise, this PR is blocked until we officially drop support for 3.9.
Checklist:
- `pre-commit` checks with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up)