2222from jax import Array
2323import jax .numpy as jnp
2424from flax import struct
25+ from dataclasses import field
2526
2627
2728from .components import (
@@ -63,10 +64,12 @@ class Event(Positionable, HasColour):
6364 happened (Array): A boolean flag indicating whether the event happened.
6465 event_type (Array): The type of event that happened."""
6566
66- position : Array = jnp .asarray ([- 1 , - 1 ], dtype = jnp .int32 )
67- colour : Array = PALETTE .UNSET
68- happened : Array = jnp .asarray (False , dtype = jnp .bool_ )
69- event_type : Array = EventType .NONE
67+ position : Array = field (
68+ default_factory = lambda : jnp .asarray ([- 1 , - 1 ], dtype = jnp .int32 )
69+ )
70+ colour : Array = field (default_factory = lambda : PALETTE .UNSET )
71+ happened : Array = field (default_factory = lambda : jnp .asarray (False , dtype = jnp .bool_ ))
72+ event_type : Array = field (default_factory = lambda : EventType .NONE )
7073
7174 def __eq__ (self , other : Event ) -> Array :
7275 return jnp .logical_and (
@@ -160,10 +163,10 @@ def record_goal_reached(self, goal: Goal, position: Array) -> EventsManager:
160163 def record_ball_hit (self , ball : Ball ) -> EventsManager :
161164 """Flags an event when the player is hit by a ball as happened and returns the
162165 updated events manager.
163-
166+
164167 Args:
165168 ball (Ball): The ball that hit the player.
166-
169+
167170 Returns:
168171 EventsManager: The updated events manager."""
169172 return self .replace (
@@ -178,11 +181,11 @@ def record_ball_hit(self, ball: Ball) -> EventsManager:
178181 def record_wall_hit (self , wall : Wall , position : Array ) -> EventsManager :
179182 """Flags an event when the player hits a wall as happened and returns the
180183 updated events manager.
181-
184+
182185 Args:
183186 wall (Wall): The wall the player hit.
184187 position (Array): The position of the wall in the grid.
185-
188+
186189 Returns:
187190 EventsManager: The updated events manager."""
188191 idx = jnp .where (wall .position == position , size = 1 )[0 ][0 ]
@@ -199,10 +202,10 @@ def record_wall_hit(self, wall: Wall, position: Array) -> EventsManager:
199202 def record_grid_hit (self , position : Array ) -> EventsManager :
200203 """Flags an event when the player hits a wall as happened and returns the
201204 updated events manager.
202-
205+
203206 Args:
204207 position (Array): The position of the wall in the grid.
205-
208+
206209 Returns:
207210 EventsManager: The updated events manager."""
208211 return self .replace (
@@ -217,11 +220,11 @@ def record_grid_hit(self, position: Array) -> EventsManager:
217220 def record_lava_fall (self , lava : Lava , position : Array ) -> EventsManager :
218221 """Flags an event when the lava falls as happened and returns the
219222 updated events manager.
220-
223+
221224 Args:
222225 lava (Lava): The lava that fell.
223226 position (Array): The position of the lava in the grid.
224-
227+
225228 Returns:
226229 EventsManager: The updated events manager."""
227230 idx = jnp .where (lava .position == position , size = 1 )[0 ][0 ]
@@ -238,11 +241,11 @@ def record_lava_fall(self, lava: Lava, position: Array) -> EventsManager:
238241 def record_key_pickup (self , key : Key , position : Array ) -> EventsManager :
239242 """Flags an event when the player picks up a key as happened and returns the
240243 updated events manager.
241-
244+
242245 Args:
243246 key (Key): The key the player picked up.
244247 position (Array): The position of the key in the grid.
245-
248+
246249 Returns:
247250 EventsManager: The updated events manager."""
248251 idx = jnp .where (key .position == position , size = 1 )[0 ][0 ]
@@ -259,11 +262,11 @@ def record_key_pickup(self, key: Key, position: Array) -> EventsManager:
259262 def record_door_opening (self , door : Door , position : Array ) -> EventsManager :
260263 """Flags an event when the player opens a door as happened and returns the
261264 updated events manager.
262-
265+
263266 Args:
264267 door (Door): The door the player opened.
265268 position (Array): The position of the door in the grid.
266-
269+
267270 Returns:
268271 EventsManager: The updated events manager."""
269272 idx = jnp .where (door .position == position , size = 1 )[0 ][0 ]
@@ -284,7 +287,7 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
284287 Args:
285288 door (Door): The door the player unlocked.
286289 position (Array): The position of the door in the grid.
287-
290+
288291 Returns:
289292 EventsManager: The updated events manager."""
290293 idx = jnp .where (door .position == position , size = 1 )[0 ][0 ]
@@ -301,11 +304,11 @@ def record_door_unlock(self, door: Door, position: Array) -> EventsManager:
301304 def record_ball_pickup (self , ball : Ball , position : Array ) -> EventsManager :
302305 """Flags an event when the player picks up a ball as happened and returns the
303306 updated events manager.
304-
307+
305308 Args:
306309 ball (Ball): The ball the player picked up.
307310 position (Array): The position of the ball in the grid.
308-
311+
309312 Returns:
310313 EventsManager: The updated events manager."""
311314 idx = jnp .where (ball .position == position , size = 1 )[0 ][0 ]
@@ -339,21 +342,21 @@ class State(struct.PyTreeNode):
339342
340343 def get_entity (self , entity_enum : str ) -> Entity :
341344 """Get an entity from the state by its enum.
342-
345+
343346 Args:
344347 entity_enum (str): The enum of the entity to get.
345-
348+
346349 Returns:
347350 Entity: The entity from the state."""
348351 return self .entities [entity_enum ]
349352
350353 def set_entity (self , entity_enum : str , entity : Entity ) -> State :
351354 """Set an entity in the state by its enum.
352-
355+
353356 Args:
354357 entity_enum (str): The enum of the entity to set.
355358 entity (Entity): The entity to set.
356-
359+
357360 Returns:
358361 State: The updated state."""
359362 self .entities [entity_enum ] = entity
0 commit comments