diff --git a/navix/environments/environment.py b/navix/environments/environment.py index 44757d5..6a71f86 100644 --- a/navix/environments/environment.py +++ b/navix/environments/environment.py @@ -31,6 +31,14 @@ from ..states import State from ..actions import DEFAULT_ACTION_SET from ..spaces import Space, Discrete, Continuous +from ..entities import EntityIds + +# Calculate maximum entity ID once at module level for efficiency +# Use vars() to get only class attributes (not inherited ones) and filter for Array instances +_MAX_ENTITY_ID = max( + int(value) for value in vars(EntityIds).values() if isinstance(value, Array) +) +_MAX_CATEGORICAL_VALUE = _MAX_ENTITY_ID + 1 class StepType(struct.PyTreeNode): @@ -225,10 +233,15 @@ def _get_obs_space_from_fn( shape=(), minimum=jnp.asarray(0.0), maximum=jnp.asarray(0.0) ) elif observation_fn == observations.categorical: - return Discrete.create(n_elements=9, shape=(height, width)) + return Discrete.create( + n_elements=_MAX_CATEGORICAL_VALUE, shape=(height, width) + ) elif observation_fn == observations.categorical_first_person: radius = observations.RADIUS - return Discrete.create(n_elements=9, shape=(radius * 2 + 1, radius * 2 + 1)) + return Discrete.create( + n_elements=_MAX_CATEGORICAL_VALUE, + shape=(radius * 2 + 1, radius * 2 + 1), + ) elif observation_fn == observations.rgb: return Discrete.create( 256, @@ -244,7 +257,7 @@ def _get_obs_space_from_fn( ) elif observation_fn == observations.symbolic: return Discrete.create( - n_elements=9, + n_elements=_MAX_CATEGORICAL_VALUE, shape=(height, width, 3), dtype=jnp.uint8, )