-
-
Notifications
You must be signed in to change notification settings - Fork 991
/
Copy pathjax_to_numpy.py
74 lines (56 loc) · 2.97 KB
/
jax_to_numpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Helper functions and wrapper class for converting between numpy and Jax."""
from __future__ import annotations
import functools
import numpy as np
import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.error import DependencyNotInstalled
from gymnasium.wrappers.to_array import ToArray, to_xp
try:
    import jax.numpy as jnp
except ImportError as err:
    # Chain the original ImportError explicitly so the traceback reports it as
    # the direct cause of the DependencyNotInstalled error.
    raise DependencyNotInstalled(
        'Jax is not installed therefore cannot call `numpy_to_jax`, run `pip install "gymnasium[jax]"`'
    ) from err

__all__ = ["JaxToNumpy", "jax_to_numpy", "numpy_to_jax"]

# Both converters are the shared ``to_xp`` helper with the target array
# namespace (``xp``) pre-bound: NumPy for ``jax_to_numpy``, jax.numpy for
# ``numpy_to_jax``.
jax_to_numpy = functools.partial(to_xp, xp=np)
numpy_to_jax = functools.partial(to_xp, xp=jnp)
class JaxToNumpy(ToArray, gym.utils.RecordConstructorArgs, gym.utils.ezpickle.EzPickle):
    """Wraps a Jax-based environment such that it can be interacted with NumPy arrays.

    Actions must be provided as numpy arrays and observations will be returned as numpy arrays.
    A vector version of the wrapper exists, :class:`gymnasium.wrappers.vector.JaxToNumpy`.

    Notes:
        The Jax To Numpy and Numpy to Jax conversion does not guarantee a roundtrip (jax -> numpy -> jax) and vice versa.
        The reason for this is jax does not support non-array values, therefore numpy ``np.int32(5) -> DeviceArray([5], dtype=jnp.int32)``

    Example:
        >>> import gymnasium as gym                                     # doctest: +SKIP
        >>> env = gym.make("JaxEnv-vx")                                 # doctest: +SKIP
        >>> env = JaxToNumpy(env)                                       # doctest: +SKIP
        >>> obs, _ = env.reset(seed=123)                                # doctest: +SKIP
        >>> type(obs)                                                   # doctest: +SKIP
        <class 'numpy.ndarray'>
        >>> action = env.action_space.sample()                          # doctest: +SKIP
        >>> obs, reward, terminated, truncated, info = env.step(action) # doctest: +SKIP
        >>> type(obs)                                                   # doctest: +SKIP
        <class 'numpy.ndarray'>
        >>> type(reward)                                                # doctest: +SKIP
        <class 'float'>
        >>> type(terminated)                                            # doctest: +SKIP
        <class 'bool'>
        >>> type(truncated)                                             # doctest: +SKIP
        <class 'bool'>

    Change logs:
     * v1.0.0 - Initially added
    """

    def __init__(self, env: gym.Env[ObsType, ActType]):
        """Wraps a jax environment such that the input and outputs are numpy arrays.

        Args:
            env: the jax environment to wrap
        """
        # No ``jnp is None`` check is needed here: the module-level import of
        # ``jax.numpy`` either binds ``jnp`` or raises DependencyNotInstalled
        # at import time, so ``jnp`` is always a module by the time this runs.
        gym.utils.RecordConstructorArgs.__init__(self)
        gym.utils.ezpickle.EzPickle.__init__(self, env)
        # The wrapped environment works in jax.numpy; callers of this wrapper
        # work in NumPy.
        super().__init__(env=env, env_xp=jnp, target_xp=np)