import numpy as np
import jax
import gymnasium as gym
import torch
from dataclasses import asdict
from brax.io import torch as brax_torch

class TorchWrapper(gym.Wrapper):
    """Wrapper that converts JAX arrays to PyTorch tensors."""

    def __init__(self, env, device):
        super().__init__(env)
        self.device = device
        self.default_params = env.default_params
        self.metadata = {
            "render_modes": ["human", "rgb_array"],
        }

        # Define observation and action spaces from the JAX env's spec.
        obs_shape = env.observation_space(self.default_params).shape
        self.observation_space = gym.spaces.Box(
            low=-1e6, high=1e6, shape=obs_shape)
        self.action_space = gym.spaces.Discrete(env.action_space(self.default_params).n)

        # Jit the reset function.
        def reset(key):
            key1, key2 = jax.random.split(key)
            obs, state = self.env.reset(key2)
            return state, obs, key1, asdict(state)

        self._reset = jax.jit(reset)

        # Jit the step function. The PRNG key is threaded through explicitly:
        # closing over self._key would bake one key into the jitted function
        # as a constant, so every step would reuse the same randomness.
        def step(key, state, action):
            key, step_key = jax.random.split(key)
            obs, env_state, reward, done, info = self.env.step(rng=step_key, state=state, action=action)
            return key, env_state, obs, reward, done, {**asdict(env_state), **info}

        self._step = jax.jit(step)

    def reset(self, seed=0, options=None):
        self.seed(seed)
        self._state, obs, self._key, info = self._reset(self._key)
        return brax_torch.jax_to_torch(obs, device=self.device), info

    def step(self, action):
        action = brax_torch.torch_to_jax(action)
        self._key, self._state, obs, reward, done, info = self._step(self._key, self._state, action)
        obs = brax_torch.jax_to_torch(obs, device=self.device)
        reward = brax_torch.jax_to_torch(reward, device=self.device)
        terminateds = brax_torch.jax_to_torch(done, device=self.device)
        # The JAX env reports a single done flag, so treat every episode end
        # as a termination and report no separate truncations.
        truncateds = torch.zeros_like(terminateds)
        info = brax_torch.jax_to_torch(info, device=self.device)
        return obs, reward, terminateds, truncateds, info

    def seed(self, seed: int = 0):
        self._key = jax.random.PRNGKey(seed)
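
# Usage sketch for TorchWrapper (illustrative only; `jax_env` is a placeholder
# for any gymnax-style env with `default_params`, `reset(key) -> (obs, state)`
# and `step(rng, state, action)`, matching the calls above):
#
#     env = TorchWrapper(jax_env, device="cpu")
#     obs, info = env.reset(seed=0)                   # obs: torch.Tensor
#     act = torch.tensor(env.action_space.sample())
#     obs, reward, term, trunc, info = env.step(act)  # torch tensors on `device`
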
class ResizeTorchWrapper(gym.Wrapper):
    """Wrapper that resizes observations to a given shape."""

    def __init__(self, env, shape):
        super().__init__(env)
        num_channels = env.observation_space.shape[-1]
        self.shape = (num_channels, shape[0], shape[1])

        # Define the observation space (channels-first after the resize).
        self.observation_space = gym.spaces.Box(
            low=-1e6, high=1e6, shape=self.shape)

    def reset(self, seed=0, options=None):
        obs, info = self.env.reset(seed, options)
        # NHWC -> NCHW, then nearest-neighbour resize.
        obs = obs.permute(0, 3, 1, 2)
        obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode="nearest")
        return obs, info

    def step(self, action):
        obs, reward, terminateds, truncateds, info = self.env.step(action)
        obs = obs.permute(0, 3, 1, 2)
        obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode="nearest")
        return obs, reward, terminateds, truncateds, info
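
# Shape sketch for ResizeTorchWrapper (illustrative numbers; assumes batched
# NHWC pixel observations from the wrapped env): with shape=(64, 64) and
# 3-channel input, reset/step map
#     obs: (num_envs, H, W, 3)  ->  (num_envs, 3, 64, 64)
# Nearest-neighbour interpolation is used, so no new pixel values are invented.
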
class RecordEpisodeStatistics4Craftax(gym.Wrapper):
    """Tracks per-env episode returns and lengths for a vectorized Craftax env.

    Expects a wrapper underneath that puts a boolean "returned_episode" mask
    into `info` at the end of each episode.
    """

    def __init__(self, env: gym.Env, deque_size: int = 100) -> None:  # deque_size is currently unused
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.episode_returns = None
        self.episode_lengths = None

    def reset(self, **kwargs):
        observations, infos = super().reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations, infos

    def step(self, actions):
        observations, rewards, terms, truncs, infos = super().step(actions)
        self.episode_returns += rewards.cpu().numpy()
        self.episode_lengths += 1
        # Snapshot the running statistics, then zero the accumulators for any
        # env whose episode just ended.
        self.returned_episode_returns[:] = self.episode_returns
        self.returned_episode_lengths[:] = self.episode_lengths
        done_mask = infos["returned_episode"].cpu().numpy().astype(np.int32)
        self.episode_returns *= 1 - done_mask
        self.episode_lengths *= 1 - done_mask
        infos["episode"] = {}
        infos["episode"]["r"] = self.returned_episode_returns
        infos["episode"]["l"] = self.returned_episode_lengths

        # Only report statistics for envs that actually finished this step.
        for idx, d in enumerate(terms):
            if not d:
                infos["episode"]["r"][idx] = 0
                infos["episode"]["l"][idx] = 0

        return observations, rewards, terms, truncs, infos
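
if __name__ == "__main__":
    # Minimal smoke test: a sketch, not a definitive recipe. It assumes the
    # `craftax` package is installed and exposes `make_craftax_env_from_name`
    # as documented in the Craftax README; adjust the import and env name to
    # your install. RecordEpisodeStatistics4Craftax is omitted because it
    # expects a log wrapper underneath that adds "returned_episode" to `info`,
    # and ResizeTorchWrapper expects batched pixel observations.
    from craftax.craftax_env import make_craftax_env_from_name  # assumed API

    env = make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=True)
    env = TorchWrapper(env, device="cpu")
    obs, info = env.reset(seed=0)
    action = torch.tensor(env.action_space.sample())
    obs, reward, terminated, truncated, info = env.step(action)
    print(obs.shape, float(reward), bool(terminated))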