11import random
2+ from typing import Any , Dict
23
34import gymnasium as gym
45from nle import nle_language_obsv
@@ -16,7 +17,6 @@ class NLELanguageWrapper(gym.Wrapper):
1617 def __init__ (self , env , vlm = False ):
1718 super ().__init__ (env )
1819 self .nle_language = nle_language_obsv .NLELanguageObsv ()
19- self .language_action_space = self .create_action_space ()
2020 self .vlm = vlm
2121 self .done = False
2222
@@ -25,9 +25,51 @@ def __init__(self, env, vlm=False):
2525 else :
2626 self .prompt_mode = "language"
2727
28+ self ._setup_action_space ()
2829 self .progress = get_progress_system (self .env )
2930 self .max_steps = self .env .unwrapped ._max_episode_steps
3031
32+ def _setup_action_space (self ):
33+ """Determines the action set and initializes the action space."""
34+ env_id_lower = self .env .spec .id .lower ()
35+ if "minihack" in env_id_lower :
36+ actions_to_descr = MINIHACK_ACTIONS_TO_DESCR
37+ elif "nethack" in env_id_lower :
38+ actions_to_descr = NLE_ACTIONS_TO_DESCR
39+ else :
40+ raise ValueError (f"Unsupported environment: { self .env .spec .id } " )
41+
42+ self ._initialize_action_mappings (actions_to_descr )
43+ self .language_action_space = Strings (list (self .action_str_enum_map .keys ()))
44+
45+ def _initialize_action_mappings (self , actions_to_descr : Dict [str , str ]):
46+ """Builds mappings from action strings to NLE enums and indices."""
47+ self .action_str_enum_map : Dict [str , Any ] = {}
48+ self .action_enum_index_map : Dict [Any , int ] = {}
49+ self .action_str_desc_map : Dict [str , str ] = {}
50+
51+ # Pre-calculate all possible action strings from the base environment
52+ all_action_strs = {
53+ action_str
54+ for action_strs in nle_language_wrapper .NLELanguageWrapper .all_nle_action_map .values ()
55+ for action_str in action_strs
56+ }
57+
58+ # Validate that all keys in our description map are valid actions
59+ missing_keys = [key for key in actions_to_descr if key not in all_action_strs ]
60+ if missing_keys :
61+ raise KeyError (f"Action keys not found in NLE's action map: { ', ' .join (missing_keys )} " )
62+
63+ for action_enum in self .env .unwrapped .actions :
64+ action_index = self .env .unwrapped .actions .index (action_enum )
65+ possible_strs = nle_language_wrapper .NLELanguageWrapper .all_nle_action_map .get (action_enum , [])
66+
67+ for action_str in possible_strs :
68+ if action_str in actions_to_descr :
69+ self .action_str_enum_map [action_str ] = action_enum
70+ self .action_enum_index_map [action_enum ] = action_index
71+ self .action_str_desc_map [action_str ] = actions_to_descr [action_str ]
72+
3173 def pre_reset (self ):
3274 self .progress = get_progress_system (self .env )
3375
@@ -102,55 +144,6 @@ def render(self, mode="human"):
102144 def get_stats (self ):
103145 return self .progress .__dict__
104146
105- def create_action_space (self ):
106- self .action_str_enum_map = {}
107- self .action_enum_index_map = {}
108- self .action_str_desc_map = {}
109-
110- if "minihack" in self .env .spec .id .lower ():
111- all_action_strs = [
112- action_str
113- for action_strs in nle_language_wrapper .NLELanguageWrapper .all_nle_action_map .values ()
114- for action_str in action_strs
115- ]
116- assert all (key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR ), ", " .join (
117- [key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs ]
118- )
119-
120- for action_enum in self .env .unwrapped .actions :
121- for action_str in nle_language_wrapper .NLELanguageWrapper .all_nle_action_map [action_enum ]:
122- if action_str not in MINIHACK_ACTIONS_TO_DESCR :
123- continue
124-
125- self .action_str_enum_map [action_str ] = action_enum
126- self .action_enum_index_map [action_enum ] = self .env .unwrapped .actions .index (action_enum )
127- self .action_str_desc_map [action_str ] = MINIHACK_ACTIONS_TO_DESCR [action_str ]
128-
129- elif "nethack" in self .env .spec .id .lower ():
130- all_action_strs = [
131- action_str
132- for action_strs in nle_language_wrapper .NLELanguageWrapper .all_nle_action_map .values ()
133- for action_str in action_strs
134- ]
135- assert all (key in all_action_strs for key in NLE_ACTIONS_TO_DESCR ), ", " .join (
136- [key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs ]
137- )
138-
139- for action_enum in self .env .unwrapped .actions :
140- for action_str in nle_language_wrapper .NLELanguageWrapper .all_nle_action_map [action_enum ]:
141- if action_str not in NLE_ACTIONS_TO_DESCR :
142- continue
143-
144- self .action_str_enum_map [action_str ] = action_enum
145- self .action_enum_index_map [action_enum ] = self .env .unwrapped .actions .index (action_enum )
146- self .action_str_desc_map [action_str ] = NLE_ACTIONS_TO_DESCR [action_str ]
147-
148- else :
149- raise ValueError (f"Unsupported environment: { self .env .spec .id } " )
150-
151- all_actions = list (self .action_str_enum_map .keys ())
152- return Strings (all_actions )
153-
154147 def ascii_render (self , chars ):
155148 rows , cols = chars .shape
156149 result = ""
0 commit comments