Skip to content

Commit 8ada2bb

Browse files
committed
use og progress
1 parent bfa6d11 commit 8ada2bb

File tree

6 files changed

+231
-105
lines changed

6 files changed

+231
-105
lines changed

balrog/environments/minihack/minihack_env.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from gymnasium import registry
55

66
import minihack # NOQA: F401
7-
from balrog.environments.minihack.minihack_progress import MiniHackProgressWrapper
87
from balrog.environments.nle import AutoMore, NLELanguageWrapper
98

109
MINIHACK_ENVS = [env_spec.id for env_spec in registry.values() if "MiniHack" in env_spec.id]
@@ -30,7 +29,6 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None)
3029
if skip_more:
3130
env = AutoMore(env)
3231

33-
env = MiniHackProgressWrapper(env, progression_on_done_only=False)
3432
env = NLELanguageWrapper(env, vlm=vlm)
3533

3634
return env

balrog/environments/minihack/minihack_progress.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

balrog/environments/nle/base.py

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,81 +3,39 @@
33
import gymnasium as gym
44
from nle import nle_language_obsv
55
from nle.language_wrapper.wrappers import nle_language_wrapper
6+
from nle.nethack import USEFUL_ACTIONS
67
from PIL import Image
78

89
from balrog.environments import Strings
10+
from balrog.environments.nle.progress import get_progress_system
911
from balrog.environments.nle.render import tty_render_image
1012
from balrog.environments.nle.render_rgb import rgb_render_image
1113

1214

1315
class NLELanguageWrapper(gym.Wrapper):
1416
def __init__(self, env, vlm=False):
1517
super().__init__(env)
16-
18+
self.nle_language = nle_language_obsv.NLELanguageObsv()
19+
self.language_action_space = self.create_action_space()
1720
self.vlm = vlm
21+
self.done = False
22+
1823
if not vlm:
1924
self.prompt_mode = "hybrid"
2025
else:
2126
self.prompt_mode = "language"
2227

23-
self.action_str_enum_map = {}
24-
self.action_enum_index_map = {}
25-
self.action_str_desc_map = {}
26-
27-
if "minihack" in self.env.spec.id.lower():
28-
all_action_strs = [
29-
action_str
30-
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
31-
for action_str in action_strs
32-
]
33-
assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join(
34-
[key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs]
35-
)
36-
37-
for action_enum in self.env.unwrapped.actions:
38-
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
39-
if action_str not in MINIHACK_ACTIONS_TO_DESCR:
40-
continue
41-
42-
self.action_str_enum_map[action_str] = action_enum
43-
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
44-
self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str]
45-
46-
elif "nethack" in self.env.spec.id.lower():
47-
all_action_strs = [
48-
action_str
49-
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
50-
for action_str in action_strs
51-
]
52-
assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join(
53-
[key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs]
54-
)
55-
56-
for action_enum in self.env.unwrapped.actions:
57-
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
58-
if action_str not in NLE_ACTIONS_TO_DESCR:
59-
continue
60-
61-
self.action_str_enum_map[action_str] = action_enum
62-
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
63-
self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str]
64-
65-
else:
66-
raise ValueError(f"Unsupported environment: {self.env.spec.id}")
67-
68-
self.nle_language = nle_language_obsv.NLELanguageObsv()
69-
self.language_action_space = self.create_action_space()
70-
self.done = False
28+
self.progress = get_progress_system(self.env)
7129
self.max_steps = self.env.unwrapped._max_episode_steps
7230

7331
def pre_reset(self):
74-
pass
32+
self.progress = get_progress_system(self.env)
7533

7634
def reset(self, **kwargs):
7735
self.pre_reset()
78-
self.obs, self.info = self.env.reset(**kwargs)
36+
obs, info = self.env.reset(**kwargs)
7937

80-
return self.post_reset(self.obs), self.info
38+
return self.post_reset(obs), info
8139

8240
def post_reset(self, nle_obs):
8341
return self.nle_process_obs(nle_obs)
@@ -91,9 +49,13 @@ def pre_step(self, action):
9149
def step(self, action):
9250
action = self.pre_step(action)
9351

94-
self.obs, reward, term, trun, self.info = self.env.step(action)
52+
obs, reward, term, trun, info = self.env.step(action)
53+
54+
done = term or trun
55+
self.done = done if not self.done else self.done
56+
self.progress.update(obs, reward, self.done, info)
9557

96-
return self.post_step(self.obs), reward, term, trun, self.info
58+
return self.post_step(obs), reward, term, trun, info
9759

9860
def post_step(self, nle_obs):
9961
return self.nle_process_obs(nle_obs)
@@ -141,9 +103,54 @@ def render(self, mode="human"):
141103
return self.env.render(mode)
142104

143105
def get_stats(self):
144-
return self.info.get("episode_extra_stats", {})
106+
return self.progress.__dict__
145107

146108
def create_action_space(self):
109+
self.action_str_enum_map = {}
110+
self.action_enum_index_map = {}
111+
self.action_str_desc_map = {}
112+
113+
if "minihack" in self.env.spec.id.lower():
114+
all_action_strs = [
115+
action_str
116+
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
117+
for action_str in action_strs
118+
]
119+
assert all(key in all_action_strs for key in MINIHACK_ACTIONS_TO_DESCR), ", ".join(
120+
[key for key in MINIHACK_ACTIONS_TO_DESCR if key not in all_action_strs]
121+
)
122+
123+
for action_enum in self.env.unwrapped.actions:
124+
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
125+
if action_str not in MINIHACK_ACTIONS_TO_DESCR:
126+
continue
127+
128+
self.action_str_enum_map[action_str] = action_enum
129+
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
130+
self.action_str_desc_map[action_str] = MINIHACK_ACTIONS_TO_DESCR[action_str]
131+
132+
elif "nethack" in self.env.spec.id.lower():
133+
all_action_strs = [
134+
action_str
135+
for action_strs in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map.values()
136+
for action_str in action_strs
137+
]
138+
assert all(key in all_action_strs for key in NLE_ACTIONS_TO_DESCR), ", ".join(
139+
[key for key in NLE_ACTIONS_TO_DESCR if key not in all_action_strs]
140+
)
141+
142+
for action_enum in self.env.unwrapped.actions:
143+
for action_str in nle_language_wrapper.NLELanguageWrapper.all_nle_action_map[action_enum]:
144+
if action_str not in NLE_ACTIONS_TO_DESCR:
145+
continue
146+
147+
self.action_str_enum_map[action_str] = action_enum
148+
self.action_enum_index_map[action_enum] = self.env.unwrapped.actions.index(action_enum)
149+
self.action_str_desc_map[action_str] = NLE_ACTIONS_TO_DESCR[action_str]
150+
151+
else:
152+
raise ValueError(f"Unsupported environment: {self.env.spec.id}")
153+
147154
all_actions = list(self.action_str_enum_map.keys())
148155
return Strings(all_actions)
149156

balrog/environments/nle/nle_env.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import nle # NOQA: F401
55
import nle_progress # NOQA: F401
66
from gymnasium import registry
7-
from nle_progress import NLEProgressWrapper
87

98
from balrog.environments.nle import AutoMore, NLELanguageWrapper
109

@@ -19,7 +18,6 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None):
1918
if skip_more:
2019
env = AutoMore(env)
2120

22-
env = NLEProgressWrapper(env, progression_on_done_only=False)
2321
env = NLELanguageWrapper(env, vlm=vlm)
2422

2523
return env

0 commit comments

Comments
 (0)