Skip to content

Commit 92ad2c7

Browse files
committed
fix: replace jax.tree_* with jax.tree
1 parent ccfde71 commit 92ad2c7

6 files changed

Lines changed: 10 additions & 12 deletions

File tree

examples/purejaxrl/ppo_minigrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,14 @@ def _loss_fn(params, traj_batch, gae, targets):
225225
), "batch size must be equal to number of steps * number of envs"
226226
permutation = jax.random.permutation(_rng, batch_size)
227227
batch = (traj_batch, advantages, targets)
228-
batch = jax.tree_util.tree_map(
228+
batch = jax.tree.map(
229229
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
230230
)
231-
shuffled_batch = jax.tree_util.tree_map(
231+
shuffled_batch = jax.tree.map(
232232
lambda x: jnp.take(x, permutation, axis=0), batch
233233
)
234234
# Mini-batch Updates
235-
minibatches = jax.tree_util.tree_map(
235+
minibatches = jax.tree.map(
236236
lambda x: jnp.reshape(
237237
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
238238
),

navix/agents/ppo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,15 @@ def update(self, train_state: TrainingState, _) -> Tuple[TrainingState, Dict]:
252252
), "batch size must be equal to number of steps * number of envs"
253253
permutation = jax.random.permutation(rng_1, n_samples)
254254
samples = (experience, advantages, targets, values) # (T, N, ...)
255-
samples = jax.tree_util.tree_map(
255+
samples = jax.tree.map(
256256
lambda x: x.reshape((n_samples,) + x.shape[2:]), samples
257257
) # (T * N, ...)
258-
shuffled_batch = jax.tree_util.tree_map(
258+
shuffled_batch = jax.tree.map(
259259
lambda x: jnp.take(x, permutation, axis=0), samples
260260
) # (T * N, ...)
261261

262262
# One epoch update over all mini-batches
263-
minibatches = jax.tree_util.tree_map(
263+
minibatches = jax.tree.map(
264264
lambda x: jnp.reshape(
265265
x, (self.hparams.num_minibatches, -1) + tuple(x.shape[1:])
266266
),

navix/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class Entity(Positionable, HasTag, HasSprite):
6666
To create an entity, use the `create` method."""
6767

6868
def __getitem__(self: T, idx) -> T:
69-
return jax.tree_util.tree_map(lambda x: x[idx], self)
69+
return jax.tree.map(lambda x: x[idx], self)
7070

7171
@property
7272
def name(self) -> str:

navix/environments/key_corridor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import jax
2525
import jax.numpy as jnp
2626
from jax import Array
27-
import jax.tree_util as jtu
2827

2928
from navix import observations, rewards, terminations
3029

@@ -116,7 +115,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
116115
open=jnp.asarray(0),
117116
)
118117
)
119-
doors = jtu.tree_map(lambda *x: jnp.stack(x), *doors)
118+
doors = jax.tree.map(lambda *x: jnp.stack(x), *doors)
120119

121120
entities = {
122121
"player": player[None],

navix/experiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,10 @@ def search(hparam_set_sample):
171171
# average over seeds
172172
for i in range(len_search_set):
173173
print("Logging results for hparam set:", search_set)
174-
hparams = jax.tree_map(lambda x: x[i], search_set)
174+
hparams = jax.tree.map(lambda x: x[i], search_set)
175175
config = {**vars(self), **asdict(hparams)}
176176
wandb.init(project=self.name, config=config, group=self.group)
177-
log = jax.tree_map(lambda x: jnp.mean(x[i], axis=0), logs)
177+
log = jax.tree.map(lambda x: jnp.mean(x[i], axis=0), logs)
178178
self.agent.log_on_train_end(log)
179179
wandb.finish()
180180
logging_time = time.time() - start_time

navix/grid.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import jax
2727
import jax.numpy as jnp
2828
from jax import Array
29-
import jax.tree_util as jtu
3029
from flax import struct
3130

3231

0 commit comments

Comments
 (0)