Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions navix/environments/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
from ..states import State
from ..actions import DEFAULT_ACTION_SET
from ..spaces import Space, Discrete, Continuous
from ..entities import EntityIds

# Calculate maximum entity ID once at module level for efficiency
# Use vars() to get only class attributes (not inherited ones) and filter for Array instances
_MAX_ENTITY_ID = max(
int(value) for value in vars(EntityIds).values() if isinstance(value, Array)
)
_MAX_CATEGORICAL_VALUE = _MAX_ENTITY_ID + 1


class StepType(struct.PyTreeNode):
Expand Down Expand Up @@ -225,10 +233,15 @@ def _get_obs_space_from_fn(
shape=(), minimum=jnp.asarray(0.0), maximum=jnp.asarray(0.0)
)
elif observation_fn == observations.categorical:
return Discrete.create(n_elements=9, shape=(height, width))
return Discrete.create(
n_elements=_MAX_CATEGORICAL_VALUE, shape=(height, width)
)
elif observation_fn == observations.categorical_first_person:
radius = observations.RADIUS
return Discrete.create(n_elements=9, shape=(radius * 2 + 1, radius * 2 + 1))
return Discrete.create(
n_elements=_MAX_CATEGORICAL_VALUE,
shape=(radius * 2 + 1, radius * 2 + 1),
)
elif observation_fn == observations.rgb:
return Discrete.create(
256,
Expand All @@ -244,7 +257,7 @@ def _get_obs_space_from_fn(
)
elif observation_fn == observations.symbolic:
return Discrete.create(
n_elements=9,
n_elements=_MAX_CATEGORICAL_VALUE,
shape=(height, width, 3),
dtype=jnp.uint8,
)
Expand Down
Loading