Skip to content

Commit 0e263d1

Browse files
authored
Merge pull request #103 from epignatelli/i92
Upgrade to Python>=3.11
2 parents 5fc4975 + 92ad2c7 commit 0e263d1

12 files changed

Lines changed: 60 additions & 39 deletions

File tree

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ jobs:
1212
max-parallel: 5
1313
matrix:
1414
os: ["ubuntu"]
15+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
1516
continue-on-error: false
1617
steps:
1718
- uses: actions/checkout@v3
1819
- uses: actions/setup-python@v4
1920
with:
20-
python-version: "3.10"
21+
python-version: ${{ matrix.python-version }}
2122
- name: Setup navix
2223
run: |
2324
pip install . -v

examples/purejaxrl/ppo_minigrid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,14 @@ def _loss_fn(params, traj_batch, gae, targets):
225225
), "batch size must be equal to number of steps * number of envs"
226226
permutation = jax.random.permutation(_rng, batch_size)
227227
batch = (traj_batch, advantages, targets)
228-
batch = jax.tree_util.tree_map(
228+
batch = jax.tree.map(
229229
lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
230230
)
231-
shuffled_batch = jax.tree_util.tree_map(
231+
shuffled_batch = jax.tree.map(
232232
lambda x: jnp.take(x, permutation, axis=0), batch
233233
)
234234
# Mini-batch Updates
235-
minibatches = jax.tree_util.tree_map(
235+
minibatches = jax.tree.map(
236236
lambda x: jnp.reshape(
237237
x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
238238
),

navix/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@
1818
# under the License.
1919

2020

21-
__version__ = "0.7.0"
21+
__version__ = "0.7.1"
2222
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())

navix/agents/ppo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,15 @@ def update(self, train_state: TrainingState, _) -> Tuple[TrainingState, Dict]:
252252
), "batch size must be equal to number of steps * number of envs"
253253
permutation = jax.random.permutation(rng_1, n_samples)
254254
samples = (experience, advantages, targets, values) # (T, N, ...)
255-
samples = jax.tree_util.tree_map(
255+
samples = jax.tree.map(
256256
lambda x: x.reshape((n_samples,) + x.shape[2:]), samples
257257
) # (T * N, ...)
258-
shuffled_batch = jax.tree_util.tree_map(
258+
shuffled_batch = jax.tree.map(
259259
lambda x: jnp.take(x, permutation, axis=0), samples
260260
) # (T * N, ...)
261261

262262
# One epoch update over all mini-batches
263-
minibatches = jax.tree_util.tree_map(
263+
minibatches = jax.tree.map(
264264
lambda x: jnp.reshape(
265265
x, (self.hparams.num_minibatches, -1) + tuple(x.shape[1:])
266266
),

navix/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class Entity(Positionable, HasTag, HasSprite):
6666
To create an entity, use the `create` method."""
6767

6868
def __getitem__(self: T, idx) -> T:
69-
return jax.tree_util.tree_map(lambda x: x[idx], self)
69+
return jax.tree.map(lambda x: x[idx], self)
7070

7171
@property
7272
def name(self) -> str:

navix/environments/key_corridor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import jax
2525
import jax.numpy as jnp
2626
from jax import Array
27-
import jax.tree_util as jtu
2827

2928
from navix import observations, rewards, terminations
3029

@@ -116,7 +115,7 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
116115
open=jnp.asarray(0),
117116
)
118117
)
119-
doors = jtu.tree_map(lambda *x: jnp.stack(x), *doors)
118+
doors = jax.tree.map(lambda *x: jnp.stack(x), *doors)
120119

121120
entities = {
122121
"player": player[None],

navix/experiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,10 @@ def search(hparam_set_sample):
171171
# average over seeds
172172
for i in range(len_search_set):
173173
print("Logging results for hparam set:", search_set)
174-
hparams = jax.tree_map(lambda x: x[i], search_set)
174+
hparams = jax.tree.map(lambda x: x[i], search_set)
175175
config = {**vars(self), **asdict(hparams)}
176176
wandb.init(project=self.name, config=config, group=self.group)
177-
log = jax.tree_map(lambda x: jnp.mean(x[i], axis=0), logs)
177+
log = jax.tree.map(lambda x: jnp.mean(x[i], axis=0), logs)
178178
self.agent.log_on_train_end(log)
179179
wandb.finish()
180180
logging_time = time.time() - start_time

navix/grid.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import jax
2727
import jax.numpy as jnp
2828
from jax import Array
29-
import jax.tree_util as jtu
3029
from flax import struct
3130

3231

navix/states.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from jax import Array
2323
import jax.numpy as jnp
2424
from flax import struct
25+
from dataclasses import field
2526

2627

2728
from .components import (
@@ -63,10 +64,12 @@ class Event(Positionable, HasColour):
6364
happened (Array): A boolean flag indicating whether the event happened.
6465
event_type (Array): The type of event that happened."""
6566

66-
position: Array = jnp.asarray([-1, -1], dtype=jnp.int32)
67-
colour: Array = PALETTE.UNSET
68-
happened: Array = jnp.asarray(False, dtype=jnp.bool_)
69-
event_type: Array = EventType.NONE
67+
position: Array = field(
68+
default_factory=lambda: jnp.asarray([-1, -1], dtype=jnp.int32)
69+
)
70+
colour: Array = field(default_factory=lambda: PALETTE.UNSET)
71+
happened: Array = field(default_factory=lambda: jnp.asarray(False, dtype=jnp.bool_))
72+
event_type: Array = field(default_factory=lambda: EventType.NONE)
7073

7174
def __eq__(self, other: Event) -> Array:
7275
return jnp.logical_and(
@@ -160,10 +163,10 @@ def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager:
160163
def record_ball_hit(self, ball: Ball) -> EventsManager:
161164
"""Flags an event when the player is hit by a ball as happened and returns the
162165
updated events manager.
163-
166+
164167
Args:
165168
ball (Ball): The ball that hit the player.
166-
169+
167170
Returns:
168171
EventsManager: The updated events manager."""
169172
return self.replace(
@@ -178,11 +181,11 @@ def record_ball_hit(self, ball: Ball) -> EventsManager:
178181
def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager:
179182
"""Flags an event when the player hits a wall as happened and returns the
180183
updated events manager.
181-
184+
182185
Args:
183186
wall (Wall): The wall the player hit.
184187
position (Array): The position of the wall in the grid.
185-
188+
186189
Returns:
187190
EventsManager: The updated events manager."""
188191
idx = jnp.where(wall.position == position, size=1)[0][0]
@@ -199,10 +202,10 @@ def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager:
199202
def record_grid_hit(self, position: Array) -> EventsManager:
200203
"""Flags an event when the player hits a wall as happened and returns the
201204
updated events manager.
202-
205+
203206
Args:
204207
position (Array): The position of the wall in the grid.
205-
208+
206209
Returns:
207210
EventsManager: The updated events manager."""
208211
return self.replace(
@@ -217,11 +220,11 @@ def record_grid_hit(self, position: Array) -> EventsManager:
217220
def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager:
218221
"""Flags an event when the lava falls as happened and returns the
219222
updated events manager.
220-
223+
221224
Args:
222225
lava (Lava): The lava that fell.
223226
position (Array): The position of the lava in the grid.
224-
227+
225228
Returns:
226229
EventsManager: The updated events manager."""
227230
idx = jnp.where(lava.position == position, size=1)[0][0]
@@ -238,11 +241,11 @@ def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager:
238241
def record_key_pickup(self, key: Key, position: Array) -> EventsManager:
239242
"""Flags an event when the player picks up a key as happened and returns the
240243
updated events manager.
241-
244+
242245
Args:
243246
key (Key): The key the player picked up.
244247
position (Array): The position of the key in the grid.
245-
248+
246249
Returns:
247250
EventsManager: The updated events manager."""
248251
idx = jnp.where(key.position == position, size=1)[0][0]
@@ -259,11 +262,11 @@ def record_key_pickup(self, key: Key, position: Array) -> EventsManager:
259262
def record_door_opening(self, door: Door, position: Array) -> EventsManager:
260263
"""Flags an event when the player opens a door as happened and returns the
261264
updated events manager.
262-
265+
263266
Args:
264267
door (Door): The door the player opened.
265268
position (Array): The position of the door in the grid.
266-
269+
267270
Returns:
268271
EventsManager: The updated events manager."""
269272
idx = jnp.where(door.position == position, size=1)[0][0]
@@ -284,7 +287,7 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
284287
Args:
285288
door (Door): The door the player unlocked.
286289
position (Array): The position of the door in the grid.
287-
290+
288291
Returns:
289292
EventsManager: The updated events manager."""
290293
idx = jnp.where(door.position == position, size=1)[0][0]
@@ -301,11 +304,11 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
301304
def record_ball_pickup(self, ball: Ball, position: Array) -> EventsManager:
302305
"""Flags an event when the player picks up a ball as happened and returns the
303306
updated events manager.
304-
307+
305308
Args:
306309
ball (Ball): The ball the player picked up.
307310
position (Array): The position of the ball in the grid.
308-
311+
309312
Returns:
310313
EventsManager: The updated events manager."""
311314
idx = jnp.where(ball.position == position, size=1)[0][0]
@@ -339,21 +342,21 @@ class State(struct.PyTreeNode):
339342

340343
def get_entity(self, entity_enum: str) -> Entity:
341344
"""Get an entity from the state by its enum.
342-
345+
343346
Args:
344347
entity_enum (str): The enum of the entity to get.
345-
348+
346349
Returns:
347350
Entity: The entity from the state."""
348351
return self.entities[entity_enum]
349352

350353
def set_entity(self, entity_enum: str, entity: Entity) -> State:
351354
"""Set an entity in the state by its enum.
352-
355+
353356
Args:
354357
entity_enum (str): The enum of the entity to set.
355358
entity (Entity): The entity to set.
356-
359+
357360
Returns:
358361
State: The updated state."""
359362
self.entities[entity_enum] = entity

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ requires = ["setuptools >= 50", "setuptools-scm[toml]>=6.2", "wheel"]
99
name = "Navix"
1010
dynamic = ["version", "dependencies"]
1111
description = "Accelerated gridworld navigation with JAX for deep reinforcement learning"
12-
requires-python = ">=3.8"
12+
requires-python = ">=3.9"
1313
readme = "README.md"
1414
license = {file = "LICENSE", name = "Apache-2.0"}
1515
authors = [

0 commit comments

Comments
 (0)