
Add generic conversion wrapper between Array API compatible frameworks #1333


Open · wants to merge 17 commits into main

Conversation


@amacati amacati commented Mar 19, 2025

Description

Motivation

Recently, environments implemented in frameworks such as torch and jax have gained importance as they provide a convenient way to run parallelized environments on GPUs and compile parts of the logic. To enable the mixing of environments written in one framework and training in another framework, gymnasium provides wrappers that automatically convert from one framework to the other and back. However, only a few combinations are currently implemented, namely JaxToNumpy, JaxToTorch and NumpyToTorch. Other combinations, e.g. an environment implemented in torch and converting to jax arrays, are missing.

Proposal

This PR proposes to add a generic ToArray wrapper from one array framework to the other. It leverages the Python array API standard (see here), which defines common operations that any framework has to implement in order to be compatible. Frameworks that are currently not fully compatible are wrapped in a compatibility layer (see array-api-compat). With this design, the wrapper can handle any combination of supported libraries. Notably, this includes previously supported frameworks such as numpy and torch, as well as new ones such as cupy.
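The conversion core of such a wrapper can be sketched as follows. This is a hedged illustration of the technique, not the PR's exact code; `to_xp_array` is a hypothetical name. The idea is to prefer zero-copy exchange via the DLPack protocol, which the array API standard mandates, and fall back to a copying `asarray` conversion:

```python
# Illustrative sketch of a generic array conversion between array API
# frameworks; `to_xp_array` is a hypothetical name, not the PR's API.
from types import ModuleType
from typing import Any

import numpy as np


def to_xp_array(value: Any, xp: ModuleType, device: Any = None):
    """Convert `value` to an array of the target namespace `xp`."""
    try:
        # Prefer zero-copy exchange via the DLPack protocol.
        x = xp.from_dlpack(value)
    except (AttributeError, TypeError, RuntimeError, BufferError):
        # Fall back to a (possibly copying) asarray conversion.
        x = xp.asarray(value)
    if device is not None:
        # Array API namespaces accept a `device` argument in asarray.
        x = xp.asarray(x, device=device)
    return x


# Any array API namespace works as the target; numpy shown here.
arr = to_xp_array([1.0, 2.0, 3.0], xp=np)
```

Because the same `xp.from_dlpack`/`xp.asarray` entry points exist in every compliant namespace, the one function covers numpy, torch, jax, cupy, and any future compatible framework.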

Advantages

One major advantage is that gymnasium can support conversions between arbitrary array API compatible frameworks without additional efforts on our side. As soon as a library is compatible with the standard, it automatically works with the new wrapper.

Furthermore, the new to_xp conversion function to an arbitrary, array API compatible framework xp only has to be defined once. This significantly reduces code duplication (compare the near-identical conversions in jax_to_torch.py and numpy_to_torch.py). This change is fully backwards compatible. We can retain the old wrappers as special cases of the ToArray wrapper, and keep the conversion functions.

As an example, numpy_to_torch can now be written as

numpy_to_torch = functools.partial(to_xp, xp=torch)

which reduces it to a special case of the generic to_xp function while maintaining backwards compatibility.

The new wrapper also allows more control over the array devices. For example, it was previously not possible to specify a device for a jax environment in JaxToTorch.

Disadvantages

While many libraries are actively working towards supporting the array API, some have not achieved full compatibility yet, e.g. torch. The array API consortium provides a compatibility layer for these cases. In order to support the core frameworks, this introduces a new dependency on array-api-compat.

In addition, the wrapper requires frameworks to support at least the 2023 version of the standard. NumPy added support for this version in numpy 2.1. Therefore, installing gymnasium with its torch or jax dependencies will raise the required numpy version to 2.1.

Note: NumPy 2.1 only supports Python>=3.10! Consequently, this also raises the Python requirement for gymnasium[jax,torch]!
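The numpy constraint can also be verified at runtime. A minimal sketch (the function name is illustrative, not part of this PR):

```python
# Hedged sketch: check whether the installed numpy is recent enough for
# the 2023 array API edition (numpy>=2.1). Function name is illustrative.
import numpy as np


def numpy_supports_array_api_2023() -> bool:
    """Return True if numpy implements the 2023 array API edition."""
    major, minor = (int(part) for part in np.__version__.split(".")[:2])
    return (major, minor) >= (2, 1)
```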

Example

Convert a torch-based environment to jax arrays using the ToArray wrapper.

import torch
import jax.numpy as jp
import gymnasium as gym
from gymnasium.wrappers.to_array import ToArray


env = gym.make("TorchEnv-vx")  # an environment implemented in torch
env = ToArray(env, env_xp=torch, target_xp=jp)
obs, _ = env.reset(seed=123)
type(obs)  # <class 'jaxlib.xla_extension.ArrayImpl'>

action = jp.array(env.action_space.sample())
obs, reward, terminated, truncated, info = env.step(action)
type(obs)  # <class 'jaxlib.xla_extension.ArrayImpl'>
type(reward)  # <class 'float'>
type(terminated)  # <class 'bool'>
type(truncated)  # <class 'bool'>

Type of change

  • New feature (non-breaking change which adds functionality)

Dependency changes

As explained in the previous sections, this PR requires the following version updates:

  • numpy>=2.1 for gymnasium[torch,jax,array-api]
  • python>=3.10 caused by numpy>=2.1
  • array-api-compat for gymnasium[torch,jax,array-api]

Blockers

The current PR will fail the test suite on Python 3.9. If we are fine with supporting torch, jax etc. for python 3.10 and onward only, this PR can move ahead. Otherwise, this PR is blocked until we officially drop support for 3.9.

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


amacati commented Mar 19, 2025

@pseudo-rnd-thoughts Any thoughts on how we might resolve the issue with Python 3.9? I think this wrapper makes a lot of sense as a consistent way of switching between array frameworks, but obviously the CI cannot fail for Python 3.9. As far as I can see, the gymnasium team would need to make a call on whether it's okay to drop 3.9 support for jax and torch. I'm not sure lower Python versions are required for compatibility, given that torch and jax support in gymnasium is fairly new.

@pseudo-rnd-thoughts
Member

@amacati Thanks for the PR, sorry I must have forgotten to review it. Yes, this looks great.
For the py 3.9 issue, I can't see what in the code requires numpy==2.1.2, causing the problem.
Otherwise, can we just set numpy<2.1 as suggested for now?


amacati commented Mar 19, 2025

No worries, I just opened this.

The dependency issue is not immediately obvious. The array API gets extended in yearly editions by the consortium once features are stabilized and deemed essential for all frameworks. The wrapper makes use of array API features that have been stabilized in the 2023 edition. However, numpy only supports those from numpy 2.1 onwards (see the release notes, first highlight at the very top).

Since I deem numpy the core array API framework for gymnasium, I would not want to ship conversion wrappers that are broken for numpy. It should at least be communicated clearly, and ideally we could specify it in the dependencies.

In short: yes, we can allow numpy<2.1, but this would render JaxToNumpy, NumpyToJax and NumpyToTorch unusable on Python 3.9.


amacati commented Mar 19, 2025

As a side note, is there any way to modify the observation and action spaces so that they automatically return arrays in the correct framework? This is something I have encountered quite frequently. It's rather unintuitive that the spaces return numpy arrays for, e.g., a jax environment. After all, a "true" observation would be a jax array.

This might warrant its own PR, just came to my mind while writing the wrapper.

@pseudo-rnd-thoughts
Member

In short: yes, we can allow numpy<2.1, but this would render JaxToNumpy, NumpyToJax and NumpyToTorch unusable on Python 3.9.

Ahh, ok. With the v1.1 release, I think we can then move to supporting python 3.10+

As a side note, is there any way to modify the observation and action spaces so that they automatically return arrays in the correct framework? This is something I have encountered quite frequently. It's rather unintuitive that the spaces return numpy arrays for, e.g., a jax environment. After all, a "true" observation would be a jax array.

Not that it automatically returns arrays in the correct framework, but we added a feature for metadata["jax"] = True and metadata["torch"] = True. However, we don't have any documentation about it currently:

if env.metadata.get("jax", False):


amacati commented Mar 19, 2025

Oh no, I didn't mean the wrapper around the environment. I meant the actual gymnasium.spaces from self.observation_space and self.action_space. Would be nice if we could support self.action_space.to(xp) or something like that, where xp is torch, jax, cupy, ... Sampling from the space would then yield arrays in the correct framework.

I.e.

type(env.observation_space.sample())  # >>> torch.tensor
torch.tensor([1.0, 2.0]) in env.observation_space  # >>> True

Regarding gymnasium 1.1: That means we are waiting for the next release, leave this PR open until then, and merge it with a numpy>=2.1 requirement?

@pseudo-rnd-thoughts
Member

Oh no, I didn't mean the wrapper around the environment. I meant the actual gymnasium.spaces from self.observation_space and self.action_space. Would be nice if we could support self.action_space.to(xp) ...

Ohh I see what you mean. In my head, I would want to modify wrappers.NumpyToJax.observation_space such that when sample is called, we can use the numpy_to_jax function to make sure this is a jax tensor.
It might be as simple as modifying the definition of sample for the observation space with a new function that wraps the data.
Does that make sense? Though this should be a different PR

Regarding gymnasium 1.1: That means we are waiting for the next release, leave this PR open until then, and merge it with a numpy>=2.1 requirement?

I meant that, with Gymnasium v1.1 fixing bugs in Gymnasium v1.0, I'm happy for v1.2 (the next release) to raise the minimum python version from 3.8 to 3.10


amacati commented Mar 19, 2025

It might be as simple as modifying the definition of sample for the observation space with a new function that wraps the data.
Does that make sense? Though this should be a different PR

Yes, makes perfect sense. Though we'd also have to redefine a bunch of other stuff to make x in obs_space work. It makes a lot of sense to also do that in the ToArray wrapper and auto-convert action and observation spaces. In addition, we should probably expose the patching function for environments that do not need conversions, but still want to have the correct spaces behaviour (torch.tensor -> torch.tensor with sampling as torch.tensor and x in space checks for torch.tensors).

I will keep this in mind and wait for this PR to land before working on it.
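The patching idea discussed above could be sketched like this. Everything here is hypothetical (none of these names exist in gymnasium); `convert`/`convert_back` stand in for to_xp-style conversions, and `DummyBox` is a minimal stand-in for a real space:

```python
# Hypothetical sketch of a space adapter whose sample()/`in` work in the
# target framework; names are illustrative, not part of gymnasium.
import numpy as np


class XpSpaceAdapter:
    """Wraps a gymnasium-style space so samples live in the target framework."""

    def __init__(self, space, convert, convert_back):
        self._space = space
        self._convert = convert            # numpy -> target framework
        self._convert_back = convert_back  # target framework -> numpy

    def sample(self):
        return self._convert(self._space.sample())

    def __contains__(self, x):
        # Delegate membership checks to the wrapped numpy-based space.
        return self._convert_back(x) in self._space


# Minimal stand-in space for illustration only.
class DummyBox:
    def sample(self):
        return np.zeros(2)

    def __contains__(self, x):
        return isinstance(x, np.ndarray) and x.shape == (2,)


# Using plain Python lists as the "target framework" for the sketch.
space = XpSpaceAdapter(DummyBox(), convert=list, convert_back=np.asarray)
```

As noted in the thread, `sample` and `__contains__` are the minimum; a full solution would also need to cover seeding, dtype handling, and the other space methods.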

I meant that, with Gymnasium v1.1 fixing bugs in Gymnasium v1.0, I'm happy for v1.2 (the next release) to raise the minimum python version from 3.8 to 3.10

Alright, looking forward to the release of 1.1 then to pick this up again!


pseudo-rnd-thoughts commented Mar 19, 2025

Yes, makes perfect sense. Though we'd also have to redefine a bunch of other stuff to make x in obs_space work. It makes a lot of sense to also do that in the ToArray wrapper and auto-convert action and observation spaces. In addition, we should probably expose the patching function for environments that do not need conversions, but still want to have the correct spaces behaviour (torch.tensor -> torch.tensor with sampling as torch.tensor and x in space checks for torch.tensors).

Ahh yes, this is why I hadn't done this previously: you would need to implement in and sample at a minimum. The easiest way would be to convert between types, which is relatively slow; therefore, people need to be aware of it in order to minimise conversions.

Also, we already cut v1.1 two weeks ago but with minimal publicity (https://github.com/Farama-Foundation/Gymnasium/releases/tag/v1.1.0)


amacati commented Mar 19, 2025

If 1.1 is out, should I update the Python requirement for the whole package to 3.10 to move forward? Or do you want to do this in a separate PR?

@pseudo-rnd-thoughts
Member

If 1.1 is out, should I update the Python requirement for the whole package to 3.10 to move forward? Or do you want to do this in a separate PR?

Update to 3.10 here


amacati commented Mar 20, 2025

I have increased the version requirements, addressed a bug in the vectorized version of ToArray, and added tests for the vector wrapper. Should be ready for a first review.

@pseudo-rnd-thoughts pseudo-rnd-thoughts left a comment

Looks good overall. I like that this simplifies implementations.

My main question is: if a particular implementation, e.g. jax, wants to implement its own conversion functions, is that possible?

Is there any difference between NumpyToTorch and JaxToTorch? Could we just have ToNumpy, ToJax, and ToTorch wrappers?

Also could you add this to the documentation (https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/misc_wrappers.md and https://github.com/Farama-Foundation/Gymnasium/blob/main/docs/api/wrappers/table.md)


amacati commented Mar 24, 2025

My main question is: if a particular implementation, e.g. jax, wants to implement its own conversion functions, is that possible?

At the moment: No. However, we can easily integrate that in the future if we want to. I'm currently working on a similar PR for scipy where we already want to support multiple implementations. The solution there is to have a backend which is really just a dictionary. When conversion functions are invoked with Array objects, we check if anyone has registered a special backend for that particular type of object. If not, we fall back to the generic Array API backend.

We can then either register custom backends on our side, or allow users to define their own, specialized versions of conversions. To give you a rough idea, this would look approximately like so

# ...

from collections import abc
from collections.abc import Iterable
from types import ModuleType
from typing import Any

from array_api_compat import array_namespace, is_array_api_obj, to_device

# (Device is a type alias as in gymnasium's existing conversion wrappers)

backend = {}

@to_xp.register(abc.Iterable)
def _iterable_to_xp(
    value: Iterable[Any], xp: ModuleType, device: Device | None = None
) -> Iterable[Any]:
    if is_array_api_obj(value):
        return backend.get(xp, _array_api_to_xp)(value, xp, device)
    ...

def _array_api_to_xp(
    value: Array, xp: ModuleType, device: Device | None = None
) -> Array:
    try:
        x = xp.from_dlpack(value)
        return to_device(x, device) if device is not None else x
    except (RuntimeError, BufferError):
        value_namespace = array_namespace(value)
        value_copy = value_namespace.asarray(value, copy=True)
        return xp.asarray(value_copy, device=device)

def special_to_jax_conversion(value, xp, device):
    return jax.numpy.from_dlpack(value, device=device)

backend[jax.numpy] = special_to_jax_conversion

# ...

The advantage of this approach is that you can let other packages define efficient conversions by exposing some official register_conversion_backend function or similar.

Is there any difference between the NumpyToTorch and JaxToTorch? Could we just have a ToNumpy, ToJax, and ToTorch wrappers?

No, there should not be any difference. I thought that you want to keep them around for backward compatibility. Also, the ToArray wrapper encompasses all of them.

ToNumpy would basically look like this:

from functools import partial

import numpy as np

from gymnasium.wrappers.to_array import ToArray

ToNumpy = partial(ToArray, target_xp=np)

Commits pushed:

  • Add ToArray wrapper to docs
  • Add array-api-extra for reliable testing
  • Fix smaller mistakes in the docs
  • Add NumPy version check for Array API wrappers

amacati commented Mar 24, 2025

I have addressed most of your comments. In addition, I noticed that some tests failed with jax on systems with a GPU because the test comparison function for Array API objects did not cover some edge cases. Therefore, I have added array-api-extra as a testing dependency, which handles all cases consistently. I commented on this in the previous implementation already, but was hesitant to introduce another dependency.

However, since this one is exclusively used for testing and is part of the Array API project, which we depend on anyway, I feel it is worth including. With it, all tests also pass on systems with GPUs.


pseudo-rnd-thoughts commented Mar 28, 2025

@amacati This looks good to me, but I want @RedTachyon to look at the PR before merging as they were one of the authors of the original data conversion wrappers. Hopefully they will look at it this weekend or next week

@pseudo-rnd-thoughts
Member

@amacati I might be wrong on this, but I realised that an important aspect of the NumpyToTorch wrapper is not just that it returns torch observations etc., but that you can pass torch actions, which are converted to NumPy for the environment. I don't believe the implementation does that. This is why the wrappers are called NumpyToTorch etc., not just ToTorch


amacati commented Mar 29, 2025

@pseudo-rnd-thoughts This is why there is an env_xp and a target_xp. env_xp converts actions to the environment framework, target_xp converts them to the framework you want the observations to reside in.

I don't believe the implementation does that. This is why the wrappers are called NumpyToTorch etc not just ToTorch

It definitely does! The old tests for NumpyToTorch etc. are still in place and are all passing, even though these wrappers have been rewritten as a special case of the ToArray wrapper. The action conversion is happening here:

action = to_xp(action, xp=self._env_xp, device=self._env_device)

Taking the ToNumpy wrapper example from above, you would always need to specify the environment framework, i.e. ToNumpy(env, env_xp=torch) for an environment that is using torch under the hood.

I wanted to automate the environment framework detection, but there is no reliable way without resetting the env to get a sample of obs. Resetting without explicit user request seems like a bad idea though, so I made the env_xp framework explicit as an argument.

Actually, this was also a motivation for having observation spaces that use the "correct" framework. That would allow us to use auto-detection through observation_space.sample

@RedTachyon RedTachyon left a comment

Hi, thanks for the PR, I really like the idea of this fairly generic array-to-array conversion. I left a bunch of comments, some important, some nitpicks.

As a caveat: I didn't actually run the new code, I'm trusting the tests on this. The review isn't super comprehensive, but hopefully it'll help catch some potential issues before they actually come up.

@@ -31,7 +31,7 @@ To install the base Gymnasium library, use `pip install gymnasium`

This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install "gymnasium[atari]"` or use `pip install "gymnasium[all]"` to install all dependencies.

We support and test for Python 3.8, 3.9, 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
We support and test for Python 3.10, 3.11 and 3.12 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
Member

Python 3.9 has some more time until EoL - I think I'd be fine equipping this feature with some explicit warning/exception that says it doesn't work on 3.9, but I'd be very careful removing support for this version just for this feature.

Contributor Author

That would render the wrappers temporarily unusable for people that previously had access to them. What's your take on this @pseudo-rnd-thoughts ? Also, EOL is in October, so that's not too far off

def render(self) -> RenderFrame | list[RenderFrame] | None:
"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
gym.utils.ezpickle.EzPickle.__init__(self, env)
Member

Do we need EzPickle here? Plus does this actually work with arbitrary envs?

If I remember right, the idea behind EzPickle is that you can "fake pickle" an environment, so that later you can reconstruct it in its original form, without necessarily preserving the actual state of the environment. This wrapper doesn't have any arguments, so I'd guess this isn't necessary. But please do check

Contributor Author

It is, but your comment made me realize we need even more than that. We cannot pickle the ToArray wrapper because it contains references to two modules, self._env_xp and self._target_xp. These will prevent pickle from working, because you cannot pickle whole modules. What's more, using them as arguments will also prevent EzPickle from working correctly. I see two ways around that:

  • Using strings instead of modules to create the wrapper (not a fan)
  • Creating custom __setstate__ and __getstate__ functions (should be fairly easy)

The reason I dislike the first option is that using the modules as input will a) force people to only use modules they have actually installed and b) prevent a magic conversion from strings to modules.

What's your opinion on that?
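The second option could look roughly like this. A sketch only, assuming the wrapper stores the frameworks under two private module attributes; `ModuleHolder` is a stand-in class, not the PR's code:

```python
# Sketch of __getstate__/__setstate__ that swap module references for
# their import names, keeping the object picklable. Names illustrative.
import importlib
import pickle

import numpy as np


class ModuleHolder:
    """Stand-in for a wrapper that stores framework modules."""

    def __init__(self, env_xp, target_xp):
        self._env_xp = env_xp
        self._target_xp = target_xp

    def __getstate__(self):
        state = self.__dict__.copy()
        # Modules cannot be pickled; store their import names instead.
        state["_env_xp"] = self._env_xp.__name__
        state["_target_xp"] = self._target_xp.__name__
        return state

    def __setstate__(self, state):
        # Re-import the frameworks from their names on unpickling.
        state["_env_xp"] = importlib.import_module(state["_env_xp"])
        state["_target_xp"] = importlib.import_module(state["_target_xp"])
        self.__dict__.update(state)


holder = ModuleHolder(env_xp=np, target_xp=np)
restored = pickle.loads(pickle.dumps(holder))
```

Because Python caches imports, the restored attributes are the very same module objects, so the user still only needs the frameworks installed at unpickling time.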

"""Returns the rendered frames as a numpy array."""
return jax_to_numpy(self.env.render())
gym.utils.ezpickle.EzPickle.__init__(self, env)
super().__init__(env=env, env_xp=jnp, target_xp=np)
Member

We're doing some extremely cursed multiple inheritance here, so to maintain some sanity I think it'd be better to specify the parent class instead of super(), since the arguments are very specific and super() can act unpredictably with multiple inheritance.

Contributor Author

See the comment about pickle. The multiple inheritance issue would go away completely. Also, all inheritance can be moved into the ToArray wrapper, so that all special wrappers only inherit from ToArray.

gym.utils.RecordConstructorArgs.__init__(self, device=device)
gym.Wrapper.__init__(self, env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.utils.ezpickle.EzPickle.__init__(self, env, device)
Member

See my comments in jax_to_numpy


# TODO: Device was part of the public API, but should be removed in favor of _env_device and
Member

Is this meant to be kept for a future PR?

Contributor Author

Ah, good catch. This was meant as a comment for reviews. JaxToTorch previously had a device attribute, which would be removed with the new ToArray wrapper in favor of _env_device and _target_device. These are private however, whereas device was public. Removing it could break existing code. The question was if you are willing to make that breaking change.


return jax_to_numpy(self.env.reset(seed=seed, options=options))
gym.utils.RecordConstructorArgs.__init__(self)
gym.utils.ezpickle.EzPickle.__init__(self, env)
Member

See jax_to_numpy.py

@@ -45,48 +48,8 @@ def __init__(self, env: VectorEnv, device: Device | None = None):
env: The NumPy-based vector environment to wrap
device: The device the torch Tensors should be moved to
"""
super().__init__(env)
gym.utils.RecordConstructorArgs.__init__(self)
gym.utils.ezpickle.EzPickle.__init__(self, env, device)
Member

See jax_to_numpy.py (side note: I just noticed that wrappers pointing to torch do actually have the extra argument of the device - still, I'm not sure if it's the right pattern for conversion wrappers)

Returns:
xp-based observations and info
"""
if options:
Member

See the non-vectorized to_array.py

Returns:
xp-based observations and info
"""
if options:
Member

This is where things might get a bit dicey. Options are meant to be relatively free-form, and I don't think we should force any specific structure on this object. In particular, values might be strings, or any other exotic objects. It's not obvious to me whether they should be automatically converted. Same thing applies to info dicts. Does the current to_xp function handle arbitrary objects that might not be easily convertible?

I'd consider adding an extra parameter to this (and the vector version) wrapper, defining the behavior of options/infos/anything else that might be questionable. Else this might cause a lot of weird interactions.

As an example: an option is an int which is then converted to a string, or maybe even directly used as a key in a dictionary. Automatically converting options to arrays might have some hard-to-predict consequences.

(side note: I realize that this issue was already present, but might as well fix it/think about it if we're revamping the whole thing anyways)

Contributor Author

It does not. Currently, anything that can't be converted raises an error. There is also no way to specify that e.g. a numpy array in the info dict should not be converted. The list of objects we can convert is limited to:

  • Numbers
  • Mappings
  • Iterables (-> Arrays)
  • None

So even a string would raise an error. We could, of course, also return unsupported objects as they are. Alternatively, one could add convert_info and convert_options arguments, but I am in favor of keeping all init args limited and simple.
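The supported-type dispatch described above could be structured roughly like this. A sketch of the idea only, not the PR's exact code; names and details are illustrative:

```python
# Sketch of a conversion dispatch over numbers, mappings, iterables and
# None; any unregistered type raises. Names are illustrative.
import functools
import numbers
from collections import abc
from types import ModuleType
from typing import Any

import numpy as np


@functools.singledispatch
def to_xp(value: Any, xp: ModuleType) -> Any:
    raise TypeError(f"No conversion known for {type(value)}")


@to_xp.register(numbers.Number)
def _number_to_xp(value, xp):
    return value  # plain scalars pass through unchanged


@to_xp.register(abc.Mapping)
def _mapping_to_xp(value, xp):
    # Convert values recursively, e.g. for info dicts.
    return {key: to_xp(item, xp) for key, item in value.items()}


@to_xp.register(abc.Iterable)
def _iterable_to_xp(value, xp):
    return xp.asarray(value)  # iterables become arrays of the target xp


@to_xp.register(type(None))
def _none_to_xp(value, xp):
    return None


info = to_xp({"obs": [1.0, 2.0], "steps": 3, "extra": None}, xp=np)
```

Since Mapping is more specific than Iterable in the ABC hierarchy, dicts dispatch to the mapping branch rather than being turned into arrays; a real implementation would also need a dedicated rule for strings.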

@@ -8,16 +8,14 @@ build-backend = "setuptools.build_meta"
name = "gymnasium"
description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)."
readme = "README.md"
requires-python = ">= 3.8"
requires-python = ">= 3.10"
Member

See my comments in another file about bumping the version requirements


amacati commented Apr 9, 2025

The latest changes address some of the concerns with pickle, though I'm not 100% satisfied. ToArray now uses its own __setstate__ and __getstate__ functions, but this might break EzPickle if it wraps around other wrappers.

Any thoughts on what a good solution would look like @RedTachyon ?


amacati commented Apr 14, 2025

How willing are you to make breaking changes @RedTachyon ? If we make the arguments to NumpyToTorch etc. the same as to ToArray, there is a pretty straightforward way of making everything work cleanly with pickle. But that would change their signature
