Skip to content

Commit 97f1085

Browse files
committed
small refactor for action mappings
1 parent 33e9b72 commit 97f1085

File tree

1 file changed

+43
-50
lines changed

1 file changed

+43
-50
lines changed

balrog/environments/nle/base.py

Lines changed: 43 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
from typing import Any, Dict
23

34
import gymnasium as gym
45
from nle import nle_language_obsv
@@ -16,7 +17,6 @@ class NLELanguageWrapper(gym.Wrapper):
1617
def __init__(self, env, vlm=False):
1718
super().__init__(env)
1819
self.nle_language = nle_language_obsv.NLELanguageObsv()
19-
self.language_action_space = self.create_action_space()
2020
self.vlm = vlm
2121
self.done = False
2222

@@ -25,9 +25,51 @@ def __init__(self, env, vlm=False):
2525
else:
2626
self.prompt_mode = "language"
2727

28+
self._setup_action_space()
2829
self.progress = get_progress_system(self.env)
2930
self.max_steps = self.env.unwrapped._max_episode_steps
3031

32+
def _setup_action_space(self):
33+
"""Determines the action set and initializes the action space."""
34+
env_id_lower = self.env.spec.id.lower()
35+
if "minihack" in env_id_lower:
36+
actions_to_descr = MINIHACK_ACTIONS_TO_DESCR
37+
elif "nethack" in env_id_lower:
38+
actions_to_descr = NLE_ACTIONS_TO_DESCR
39+
else:
40+
raise ValueError(f"Unsupported environment: {self.env.spec.id}")
41+
42+
self._initialize_action_mappings(actions_to_descr)
43+
self.language_action_space = Strings(list(self.action_str_enum_map.keys()))
44+
45+
def _initialize_action_mappings(self, actions_to_descr: Dict[str, str]):
46+
"""Builds mappings from action strings to NLE enums and indices."""
47+
self.action_str_enum_map: Dict[str, Any] = {}
48+
self.action_enum_index_map: Dict[Any, int] = {}
49+
self.action_str_desc_map: Dict[str, str] = {}
50+
51+
# Pre-calculate all possible action strings from the base environment
52+
all_action_strs = {
53+
action_str
54+
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
55+
for action_str in action_strs
56+
}
57+
58+
# Validate that all keys in our description map are valid actions
59+
missing_keys = [key for key in actions_to_descr if key not in all_action_strs]
60+
if missing_keys:
61+
raise KeyError(f"Action keys not found in NLE's action map: {', '.join(missing_keys)}")
62+
63+
for action_enum in self.env.unwrapped.actions:
64+
action_index = self.env.unwrapped.actions.index(action_enum)
65+
possible_strs = nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.get(action_enum, [])
66+
67+
for action_str in possible_strs:
68+
if action_str in actions_to_descr:
69+
self.action_str_enum_map[action_str] = action_enum
70+
self.action_enum_index_map[action_enum] = action_index
71+
self.action_str_desc_map[action_str] = actions_to_descr[action_str]
72+
3173
def pre_reset(self):
3274
self.progress = get_progress_system(self.env)
3375

@@ -102,55 +144,6 @@ def render(self, mode="human"):
102144
def get_stats(self):
103145
return self.progress.__dict__
104146

105-
def create_action_space(self):
106-
self.action_str_enum_map = {}
107-
self.action_enum_index_map = {}
108-
self.action_str_desc_map = {}
109-
110-
if "minihack" in self.env.spec.id.lower():
111-
all_action_strs = [
112-
action_str
113-
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
114-
for action_str in action_strs
115-
]
116-
assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join(
117-
[key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs]
118-
)
119-
120-
for action_enum in self.env.unwrapped.actions:
121-
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
122-
if action_str not in MINIHACK_ACTIONS_TO_DESCR:
123-
continue
124-
125-
self.action_str_enum_map[action_str] = action_enum
126-
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
127-
self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str]
128-
129-
elif "nethack" in self.env.spec.id.lower():
130-
all_action_strs = [
131-
action_str
132-
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
133-
for action_str in action_strs
134-
]
135-
assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join(
136-
[key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs]
137-
)
138-
139-
for action_enum in self.env.unwrapped.actions:
140-
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
141-
if action_str not in NLE_ACTIONS_TO_DESCR:
142-
continue
143-
144-
self.action_str_enum_map[action_str] = action_enum
145-
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
146-
self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str]
147-
148-
else:
149-
raise ValueError(f"Unsupported environment: {self.env.spec.id}")
150-
151-
all_actions = list(self.action_str_enum_map.keys())
152-
return Strings(all_actions)
153-
154147
def ascii_render(self, chars):
155148
rows, cols = chars.shape
156149
result = ""

0 commit comments

Comments
 (0)