Skip to content

Commit 393505a

Browse files
authored
Optional additional info in crafter observation (#55)
* fix crafter observations
* display only edge of the water
* set legacy crafter rendering as default
* add comments
* reset defaults
* add scipy
* quick fix
* fix crafter bug
1 parent 4dfee35 commit 393505a

File tree

4 files changed

+141
-31
lines changed

4 files changed

+141
-31
lines changed

balrog/config/config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ envs:
6666
size: [256, 256] # Image size in Crafter
6767
reward: True
6868
seed: null
69-
max_episode_steps: 2000
69+
max_episode_steps: 2000
70+
unique_items: True # False
71+
precise_location: False # True
72+
skip_items: [] # ["grass", "sand", "path"]
73+
edge_only_items: [] # ["water", "lava"]
7074
textworld_kwargs:
7175
objective: True
7276
description: True
Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
11
from typing import Optional
22

33
import crafter
4-
54
from balrog.environments.crafter import CrafterLanguageWrapper
65
from balrog.environments.wrappers import GymV21CompatibilityV0
76

87

98
def make_crafter_env(env_name, task, config, render_mode: Optional[str] = None):
109
crafter_kwargs = dict(config.envs.crafter_kwargs)
1110
max_episode_steps = crafter_kwargs.pop("max_episode_steps", 2)
11+
unique_items = crafter_kwargs.pop("unique_items", True)
12+
precise_location = crafter_kwargs.pop("precise_location", False)
13+
skip_items = crafter_kwargs.pop("skip_items", [])
14+
edge_only_items = crafter_kwargs.pop("edge_only_items", [])
1215

1316
for param in ["area", "view", "size"]:
1417
if param in crafter_kwargs:
1518
crafter_kwargs[param] = tuple(crafter_kwargs[param])
1619

1720
env = crafter.Env(**crafter_kwargs)
18-
env = CrafterLanguageWrapper(env, task, max_episode_steps=max_episode_steps)
21+
env = CrafterLanguageWrapper(
22+
env,
23+
task,
24+
max_episode_steps=max_episode_steps,
25+
unique_items=unique_items,
26+
precise_location=precise_location,
27+
skip_items=skip_items,
28+
edge_only_items=edge_only_items,
29+
)
1930
env = GymV21CompatibilityV0(env=env, render_mode=render_mode)
2031

2132
return env

balrog/environments/crafter/env.py

Lines changed: 122 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import itertools
2+
import re
3+
from collections import defaultdict
24

35
import crafter
46
import gym
57
import numpy as np
68
from PIL import Image
9+
from scipy import ndimage
710

811
from balrog.environments import Strings
912

@@ -77,7 +80,29 @@ def rotation_matrix(v1, v2):
7780
return rotation_matrix
7881

7982

80-
def describe_loc(ref, P):
83+
def describe_loc_precise(ref, P):
    """
    Describe the location of P relative to ref.
    Example: `1 step south and 4 steps west`

    Each axis with a non-zero offset contributes one
    "<n> step(s) <direction>" phrase. The index-1 axis (north/south) is
    reported before the index-0 axis (west/east).

    Args:
        ref: Indexable pair holding the reference position.
        P: Indexable pair holding the position to describe.

    Returns:
        The phrases joined with " and ", or `at your location` when P
        equals ref on both axes.
    """

    def distance_to_string(distance, direction):
        # Callers pass a signed offset; only its magnitude is reported.
        steps = abs(distance)
        return f"{steps} step{'s' if steps > 1 else ''} {direction}"

    desc = []

    # Index 1 decides north/south: a smaller P[1] than ref[1] is "north".
    if ref[1] > P[1]:
        desc.append(distance_to_string(ref[1] - P[1], "north"))
    elif ref[1] < P[1]:
        desc.append(distance_to_string(ref[1] - P[1], "south"))

    # Index 0 decides west/east: a smaller P[0] than ref[0] is "west".
    if ref[0] > P[0]:
        desc.append(distance_to_string(ref[0] - P[0], "west"))
    elif ref[0] < P[0]:
        desc.append(distance_to_string(ref[0] - P[0], "east"))

    return " and ".join(desc) if desc else "at your location"
103+
104+
105+
def describe_loc_old(ref, P):
81106
desc = []
82107
if ref[1] > P[1]:
83108
desc.append("north")
@@ -88,53 +113,96 @@ def describe_loc(ref, P):
88113
elif ref[0] < P[0]:
89114
desc.append("east")
90115

91-
return "-".join(desc)
116+
distance = abs(ref[1] - P[1]) + abs(ref[0] - P[0])
117+
distance_str = f"{distance} step{'s' if distance > 1 else ''} to your {'-'.join(desc)}"
118+
119+
return distance_str
92120

93121

94-
def describe_env(info):
122+
def get_edge_items(semantic, item_idx):
    """
    Return a boolean mask of the `item_idx` cells that border another item.

    Args:
        semantic: Array of item ids (compared element-wise with `item_idx`).
        item_idx: Id of the item whose boundary cells are wanted.

    Returns:
        Boolean array shaped like `semantic`. A cell is True when it holds
        `item_idx` and is reached by dilating the non-item region by one
        step with `ndimage.binary_dilation`'s default structuring element.
    """
    item_mask = semantic == item_idx
    not_item_mask = semantic != item_idx
    # Growing the non-item region by one step and intersecting it with the
    # item region keeps only the item cells that touch a non-item cell.
    item_edge = ndimage.binary_dilation(not_item_mask) & item_mask
    return item_edge
127+
128+
129+
def describe_env(
130+
info,
131+
unique_items=True,
132+
precise_location=False,
133+
skip_items=[],
134+
edge_only_items=[],
135+
):
95136
assert info["semantic"][info["player_pos"][0], info["player_pos"][1]] == player_idx
96137
semantic = info["semantic"][
97138
info["player_pos"][0] - info["view"][0] // 2 : info["player_pos"][0] + info["view"][0] // 2 + 1,
98139
info["player_pos"][1] - info["view"][1] // 2 + 1 : info["player_pos"][1] + info["view"][1] // 2,
99140
]
100141
center = np.array([info["view"][0] // 2, info["view"][1] // 2 - 1])
101142
result = ""
102-
x = np.arange(semantic.shape[1])
103-
y = np.arange(semantic.shape[0])
104-
x1, y1 = np.meshgrid(x, y)
105-
loc = np.stack((y1, x1), axis=-1)
106-
dist = np.absolute(center - loc).sum(axis=-1)
143+
describe_loc = describe_loc_precise if precise_location else describe_loc_old
107144
obj_info_list = []
108145

109146
facing = info["player_facing"]
110-
max_y, max_x = semantic.shape
147+
max_x, max_y = semantic.shape
111148
target_x = center[0] + facing[0]
112149
target_y = center[1] + facing[1]
113150

114151
if 0 <= target_x < max_x and 0 <= target_y < max_y:
115152
target_id = semantic[int(target_x), int(target_y)]
116153
target_item = id_to_item[target_id]
154+
155+
# report skipped items (e.g. grass, sand or path) as "nothing" here, since we are not displaying them
156+
if target_id in [id_to_item.index(o) for o in skip_items]:
157+
target_item = "nothing"
158+
117159
obs = "You face {} at your front.".format(target_item)
118160
else:
119161
obs = "You face nothing at your front."
120162

121-
for idx in np.unique(semantic):
122-
if idx == player_idx:
123-
continue
124-
125-
smallest = np.unravel_index(np.argmin(np.where(semantic == idx, dist, np.inf)), semantic.shape)
126-
obj_info_list.append(
127-
(
128-
id_to_item[idx],
129-
dist[smallest],
130-
describe_loc(np.array([0, 0]), smallest - center),
163+
# Edge detection
164+
edge_masks = {}
165+
for item_name in edge_only_items:
166+
item_idx = id_to_item.index(item_name)
167+
edge_masks[item_idx] = get_edge_items(semantic, item_idx)
168+
169+
for i in range(semantic.shape[0]):
170+
for j in range(semantic.shape[1]):
171+
idx = semantic[i, j]
172+
if idx == player_idx:
173+
continue
174+
175+
# only display the edge of items that are in edge_only_items
176+
if idx in edge_masks and not edge_masks[idx][i, j]:
177+
continue
178+
179+
# skip grass, sand or path so obs is not too long
180+
if idx in [id_to_item.index(o) for o in skip_items]:
181+
continue
182+
183+
obj_info_list.append((id_to_item[idx], describe_loc(np.array([0, 0]), np.array([i, j]) - center)))
184+
185+
def extract_numbers(s):
186+
"""Extract all numbers from a string."""
187+
return [int(num) for num in re.findall(r"\d+", s)]
188+
189+
# filter out items, so we only display closest item of each type
190+
if unique_items:
191+
closest_obj_info_list = defaultdict(str)
192+
for item_name, loc in obj_info_list:
193+
loc_dist = sum(extract_numbers(loc))
194+
current_dist = (
195+
sum(extract_numbers(closest_obj_info_list[item_name]))
196+
if closest_obj_info_list[item_name]
197+
else float("inf")
131198
)
132-
)
199+
200+
if loc_dist < current_dist:
201+
closest_obj_info_list[item_name] = loc
202+
obj_info_list = [(name, loc) for name, loc in closest_obj_info_list.items()]
133203

134204
if len(obj_info_list) > 0:
135-
status_str = "You see:\n{}".format(
136-
"\n".join(["- {} {} steps to your {}".format(name, dist, loc) for name, dist, loc in obj_info_list])
137-
)
205+
status_str = "You see:\n{}".format("\n".join(["- {} {}".format(name, loc) for name, loc in obj_info_list]))
138206
else:
139207
status_str = "You see nothing away from you."
140208
result += status_str + "\n\n"
@@ -167,19 +235,30 @@ def describe_status(info):
167235
return ""
168236

169237

170-
def describe_frame(info):
238+
def describe_frame(
239+
info,
240+
unique_items=True,
241+
precise_location=False,
242+
skip_items=[],
243+
edge_only_items=[],
244+
):
171245
try:
172246
result = ""
173247

174248
result += describe_status(info)
175249
result += "\n\n"
176-
result += describe_env(info)
250+
result += describe_env(
251+
info,
252+
unique_items=unique_items,
253+
precise_location=precise_location,
254+
skip_items=skip_items,
255+
edge_only_items=edge_only_items,
256+
)
177257
result += "\n\n"
178258

179259
return result.strip(), describe_inventory(info)
180260
except Exception:
181-
breakpoint()
182-
return "Error, you are out of the map."
261+
return "Error, you are out of the map.", describe_inventory(info)
183262

184263

185264
class CrafterLanguageWrapper(gym.Wrapper):
@@ -191,6 +270,10 @@ def __init__(
191270
env,
192271
task="",
193272
max_episode_steps=2,
273+
unique_items=True,
274+
precise_location=False,
275+
skip_items=[],
276+
edge_only_items=[],
194277
):
195278
super().__init__(env)
196279
self.score_tracker = 0
@@ -199,6 +282,11 @@ def __init__(
199282
self.max_steps = max_episode_steps
200283
self.achievements = None
201284

285+
self.unique_items = unique_items
286+
self.precise_location = precise_location
287+
self.skip_items = skip_items
288+
self.edge_only_items = edge_only_items
289+
202290
def get_text_action(self, action):
203291
return self.language_action_space._values[action]
204292

@@ -232,7 +320,13 @@ def step(self, action):
232320

233321
def process_obs(self, obs, info):
234322
img = Image.fromarray(self.env.render()).convert("RGB")
235-
long_term_context, short_term_context = describe_frame(info)
323+
long_term_context, short_term_context = describe_frame(
324+
info,
325+
unique_items=self.unique_items,
326+
precise_location=self.precise_location,
327+
skip_items=self.skip_items,
328+
edge_only_items=self.edge_only_items,
329+
)
236330

237331
return {
238332
"text": {

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
"opencv-python-headless",
2727
"wandb",
2828
"pytest",
29+
"scipy",
2930
"crafter",
3031
"gym==0.23",
3132
"requests",

0 commit comments

Comments
 (0)