Skip to content

Commit 0d07d3f

Browse files
committed
change nle language wrapper so we can set render_mode
1 parent 97f1085 commit 0d07d3f

File tree

5 files changed

+23
-12
lines changed

5 files changed

+23
-12
lines changed

balrog/config/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ envs:
4141
names: babyai-babaisai-textworld-crafter-nle-minihack # Environments to evaluate, separated by hyphens
4242
env_kwargs:
4343
seed: null # Random seed; null means a random seed is used
44+
render_mode: null
4445
nle_kwargs:
4546
character: "@" # Character representing the agent in NLE
4647
max_episode_steps: 100_000 # Max steps per episode in NLE

balrog/environments/minihack/minihack_env.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ def make_minihack_env(env_name, task, config, render_mode: Optional[str] = None)
2525
"tty_colors",
2626
],
2727
**minihack_kwargs,
28+
render_mode=render_mode,
2829
)
2930
if skip_more:
3031
env = AutoMore(env)

balrog/environments/nle/base.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,16 @@ def step(self, action):
102102
def post_step(self, nle_obs):
103103
return self.nle_process_obs(nle_obs)
104104

105+
def render(self):
106+
mode = self.env.render_mode
107+
108+
if mode == "tiles":
109+
return self.render_tiles_from_obs(self.env.unwrapped._last_obs)
110+
elif mode == "tty_image":
111+
return self.render_tty_from_obs(self.env.unwrapped._last_obs)
112+
else:
113+
return self.env.render()
114+
105115
@property
106116
def default_action(self):
107117
if "minihack" in self.env.spec.id.lower():
@@ -113,7 +123,7 @@ def get_text_action(self, action):
113123
return self.action_str_enum_map[action]
114124

115125
def nle_process_obs(self, nle_obs):
116-
img = Image.fromarray(self.render("tiles")).convert("RGB") if self.vlm else None
126+
img = Image.fromarray(self.render_tiles_from_obs(nle_obs)).convert("RGB") if self.vlm else None
117127
text = self.nle_obs_type(nle_obs)
118128

119129
return {
@@ -131,15 +141,14 @@ def nle_obs_type(self, nle_obs):
131141
else:
132142
raise ValueError(f'"{self.prompt_mode}" is not a valid prompt mode.')
133143

134-
def render(self, mode="human"):
135-
if mode in ("tiles", "tty_image"):
136-
obs = self.env.unwrapped.last_observation
137-
key_idx = self.env.unwrapped._observation_keys.index
138-
if mode == "tiles":
139-
return rgb_render_image(obs[key_idx("glyphs")])
140-
else:
141-
return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")])
142-
return self.env.render(mode)
144+
def render_tiles_from_obs(self, obs):
145+
# Custom tiles rendering from latest observation
146+
key_idx = self.env.unwrapped._observation_keys.index
147+
return rgb_render_image(obs[key_idx("glyphs")])
148+
149+
def render_tty_from_obs(self, obs):
150+
key_idx = self.env.unwrapped._observation_keys.index
151+
return tty_render_image(obs[key_idx("tty_chars")], obs[key_idx("tty_colors")])
143152

144153
def get_stats(self):
145154
return self.progress.__dict__

balrog/environments/nle/nle_env.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ def make_nle_env(env_name, task, config, render_mode: Optional[str] = None):
1414
nle_kwargs = dict(config.envs.nle_kwargs)
1515
skip_more = nle_kwargs.pop("skip_more", False)
1616
vlm = True if config.agent.max_image_history > 0 else False
17-
env = gym.make(task, **nle_kwargs)
17+
env = gym.make(task, **nle_kwargs, render_mode=render_mode)
1818
if skip_more:
1919
env = AutoMore(env)
2020

balrog/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,7 @@ def run_episode(self, task, agent, process_num=None, position=0, episode_idx=0):
254254
Returns:
255255
dict: Log of the episode containing statistics and results.
256256
"""
257-
env = make_env(self.env_name, task, self.config)
257+
env = make_env(self.env_name, task, self.config, render_mode=self.config.envs.env_kwargs.render_mode)
258258
agent.reset()
259259

260260
seed = self.config.envs.env_kwargs.seed

0 commit comments

Comments
 (0)