33import gymnasium as gym
44from nle import nle_language_obsv
55from nle .language_wrapper .wrappers import nle_language_wrapper
6+ from nle .nethack import USEFUL_ACTIONS
67from PIL import Image
78
89from balrog .environments import Strings
10+ from balrog .environments .nle .progress import get_progress_system
911from balrog .environments .nle .render import tty_render_image
1012from balrog .environments .nle .render_rgb import rgb_render_image
1113
1214
1315class NLELanguageWrapper (gym .Wrapper ):
def __init__(self, env, vlm=False):
    """Wrap ``env`` and set up language-action and progress bookkeeping.

    Args:
        env: The environment to wrap. ``create_action_space`` reads
            ``env.spec.id`` and ``env.unwrapped.actions``; this constructor
            reads ``env.unwrapped._max_episode_steps``.
        vlm: Selects the prompt mode: ``"language"`` when truthy,
            ``"hybrid"`` otherwise.

    Raises:
        ValueError: Propagated from ``create_action_space`` when the
            environment id names neither MiniHack nor NetHack.
    """
    super().__init__(env)

    self.nle_language = nle_language_obsv.NLELanguageObsv()
    # create_action_space() also populates the action_str_enum_map /
    # action_enum_index_map / action_str_desc_map lookup tables on self.
    self.language_action_space = self.create_action_space()
    self.vlm = vlm
    # Episode-finished flag; step() latches it to True on term/trunc.
    self.done = False

    self.prompt_mode = "language" if vlm else "hybrid"

    # NOTE(review): get_progress_system is defined outside this file;
    # presumably it returns a per-episode statistics tracker -- confirm.
    self.progress = get_progress_system(self.env)
    self.max_steps = self.env.unwrapped._max_episode_steps
7230
def pre_reset(self):
    """Hook run by ``reset`` before the wrapped environment is reset.

    Replaces ``self.progress`` with a new object obtained from
    ``get_progress_system(self.env)``.
    """
    # NOTE(review): presumably this discards the previous episode's
    # statistics -- confirm against get_progress_system.
    self.progress = get_progress_system(self.env)
7533
def reset(self, **kwargs):
    """Reset the wrapped environment and return the processed observation.

    Args:
        **kwargs: Forwarded unchanged to ``self.env.reset``.

    Returns:
        A ``(observation, info)`` tuple, where ``observation`` is the raw
        reset observation passed through ``post_reset`` and ``info`` is
        returned as received from the wrapped environment.
    """
    self.pre_reset()
    obs, info = self.env.reset(**kwargs)

    # step() latches self.done once an episode terminates or truncates.
    # Clear it here so a reused wrapper does not report done=True to the
    # progress tracker from the first step of the next episode.
    self.done = False

    return self.post_reset(obs), info
8139
def post_reset(self, nle_obs):
    """Convert the observation produced by a reset via ``nle_process_obs``.

    Args:
        nle_obs: The observation returned by the wrapped environment's reset.

    Returns:
        Whatever ``self.nle_process_obs(nle_obs)`` returns.
    """
    processed = self.nle_process_obs(nle_obs)
    return processed
@@ -91,9 +49,13 @@ def pre_step(self, action):
def step(self, action):
    """Advance the wrapped environment by one action.

    Args:
        action: The action as supplied by the caller; it is translated by
            ``pre_step`` before being handed to the wrapped environment.

    Returns:
        ``(observation, reward, term, trun, info)``, where ``observation``
        is the raw observation passed through ``post_step`` and the other
        four values are returned as received from the wrapped environment.
    """
    action = self.pre_step(action)

    obs, reward, term, trun, info = self.env.step(action)

    # Latch: once the episode has terminated or truncated, self.done stays
    # truthy for every later step.
    self.done = self.done or term or trun
    self.progress.update(obs, reward, self.done, info)

    return self.post_step(obs), reward, term, trun, info
9759
def post_step(self, nle_obs):
    """Convert the observation produced by a step via ``nle_process_obs``.

    Args:
        nle_obs: The observation returned by the wrapped environment's step.

    Returns:
        Whatever ``self.nle_process_obs(nle_obs)`` returns.
    """
    processed = self.nle_process_obs(nle_obs)
    return processed
@@ -141,9 +103,54 @@ def render(self, mode="human"):
141103 return self .env .render (mode )
142104
def get_stats(self):
    """Return the attributes of the current progress tracker as a dict.

    Returns:
        The instance ``__dict__`` of ``self.progress``. This is the live
        mapping, not a copy, so later updates to the tracker are visible
        through it.
    """
    return vars(self.progress)
145107
def create_action_space(self):
    """Build the language action space and its lookup tables.

    Populates three dictionaries on ``self``:

    * ``action_str_enum_map``: action string -> action enum
    * ``action_enum_index_map``: action enum -> its index in
      ``self.env.unwrapped.actions``
    * ``action_str_desc_map``: action string -> description

    Only action strings that have an entry in the description table for the
    environment family (MiniHack or NetHack, chosen by a case-insensitive
    substring match on ``self.env.spec.id``) are included.

    Returns:
        A ``Strings`` space over the supported action strings.

    Raises:
        ValueError: If the environment id contains neither ``"minihack"``
            nor ``"nethack"``.
        AssertionError: If the description table has a key that is not an
            action string known to ``all_nle_action_map``; the message lists
            the offending keys.
    """
    self.action_str_enum_map = {}
    self.action_enum_index_map = {}
    self.action_str_desc_map = {}

    # "minihack" is tested first, matching the original branch order.
    env_id = self.env.spec.id.lower()
    if "minihack" in env_id:
        descriptions = MINIHACK_ACTIONS_TO_DESCR
    elif "nethack" in env_id:
        descriptions = NLE_ACTIONS_TO_DESCR
    else:
        raise ValueError(f"Unsupported environment: {self.env.spec.id}")

    nle_action_map = nle_language_wrapper.NLELanguageWrapper.all_nle_action_map

    # Sanity check: every described action must be a string the language
    # wrapper knows about.
    known_action_strs = {
        action_str for action_strs in nle_action_map.values() for action_str in action_strs
    }
    unknown = [key for key in descriptions if key not in known_action_strs]
    assert not unknown, ", ".join(unknown)

    env_actions = self.env.unwrapped.actions
    # First-occurrence index of each action, computed once instead of
    # calling .index() inside the nested loop.
    first_index = {}
    for position, action_enum in enumerate(env_actions):
        first_index.setdefault(action_enum, position)

    for action_enum in env_actions:
        for action_str in nle_action_map[action_enum]:
            if action_str not in descriptions:
                continue

            self.action_str_enum_map[action_str] = action_enum
            self.action_enum_index_map[action_enum] = first_index[action_enum]
            self.action_str_desc_map[action_str] = descriptions[action_str]

    return Strings(list(self.action_str_enum_map))
149156
0 commit comments