11import itertools
2+ import re
3+ from collections import defaultdict
24
35import crafter
46import gym
57import numpy as np
68from PIL import Image
9+ from scipy import ndimage
710
811from balrog .environments import Strings
912
@@ -77,7 +80,29 @@ def rotation_matrix(v1, v2):
7780 return rotation_matrix
7881
7982
80- def describe_loc (ref , P ):
def describe_loc_precise(ref, P):
    """
    Describe the location of P relative to ref as per-axis step counts.

    Index 1 of each point is compared for the north/south component and
    index 0 for the west/east component; the vertical component is listed
    first.

    Example: `1 step south and 4 steps west`

    Args:
        ref: Indexable pair used as the reference point.
        P: Indexable pair to be described relative to ``ref``.

    Returns:
        The components joined with " and ", or "at your location" when
        both coordinates of ``P`` equal those of ``ref``.
    """

    def distance_to_string(distance, direction):
        # Pluralise only when more than one step is involved.
        steps = abs(distance)
        return f"{steps} step{'s' if steps > 1 else ''} {direction}"

    # Compute each axis delta once; its sign selects the direction name.
    delta_vertical = ref[1] - P[1]
    delta_horizontal = ref[0] - P[0]

    desc = []
    if delta_vertical > 0:
        desc.append(distance_to_string(delta_vertical, "north"))
    elif delta_vertical < 0:
        desc.append(distance_to_string(delta_vertical, "south"))
    if delta_horizontal > 0:
        desc.append(distance_to_string(delta_horizontal, "west"))
    elif delta_horizontal < 0:
        desc.append(distance_to_string(delta_horizontal, "east"))

    return " and ".join(desc) if desc else "at your location"
103+
104+
105+ def describe_loc_old (ref , P ):
81106 desc = []
82107 if ref [1 ] > P [1 ]:
83108 desc .append ("north" )
@@ -88,53 +113,96 @@ def describe_loc(ref, P):
88113 elif ref [0 ] < P [0 ]:
89114 desc .append ("east" )
90115
91- return "-" .join (desc )
116+ distance = abs (ref [1 ] - P [1 ]) + abs (ref [0 ] - P [0 ])
117+ distance_str = f"{ distance } step{ 's' if distance > 1 else '' } to your { '-' .join (desc )} "
118+
119+ return distance_str
92120
93121
94- def describe_env (info ):
def get_edge_items(semantic, item_idx):
    """
    Return a boolean mask of the cells equal to `item_idx` that border another value.

    A cell is kept when it equals ``item_idx`` and lies inside the binary
    dilation (scipy's default structuring element) of the cells that do
    not equal ``item_idx``.

    Args:
        semantic: Array-like grid of item ids.
        item_idx: The id whose boundary cells are wanted.

    Returns:
        Boolean array with the same shape as ``semantic``.
    """
    # Accept any array-like so the comparison is always element-wise.
    item_mask = np.asarray(semantic) == item_idx
    # Growing the complement by one step marks every cell next to a non-item
    # cell; intersecting with the item mask keeps only the item's boundary.
    return ndimage.binary_dilation(~item_mask) & item_mask
127+
128+
129+ def describe_env (
130+ info ,
131+ unique_items = True ,
132+ precise_location = False ,
133+ skip_items = [],
134+ edge_only_items = [],
135+ ):
95136 assert info ["semantic" ][info ["player_pos" ][0 ], info ["player_pos" ][1 ]] == player_idx
96137 semantic = info ["semantic" ][
97138 info ["player_pos" ][0 ] - info ["view" ][0 ] // 2 : info ["player_pos" ][0 ] + info ["view" ][0 ] // 2 + 1 ,
98139 info ["player_pos" ][1 ] - info ["view" ][1 ] // 2 + 1 : info ["player_pos" ][1 ] + info ["view" ][1 ] // 2 ,
99140 ]
100141 center = np .array ([info ["view" ][0 ] // 2 , info ["view" ][1 ] // 2 - 1 ])
101142 result = ""
102- x = np .arange (semantic .shape [1 ])
103- y = np .arange (semantic .shape [0 ])
104- x1 , y1 = np .meshgrid (x , y )
105- loc = np .stack ((y1 , x1 ), axis = - 1 )
106- dist = np .absolute (center - loc ).sum (axis = - 1 )
143+ describe_loc = describe_loc_precise if precise_location else describe_loc_old
107144 obj_info_list = []
108145
109146 facing = info ["player_facing" ]
110- max_y , max_x = semantic .shape
147+ max_x , max_y = semantic .shape
111148 target_x = center [0 ] + facing [0 ]
112149 target_y = center [1 ] + facing [1 ]
113150
114151 if 0 <= target_x < max_x and 0 <= target_y < max_y :
115152 target_id = semantic [int (target_x ), int (target_y )]
116153 target_item = id_to_item [target_id ]
154+
155+ # skip grass, sand or path so obs here, since we are not displaying them
156+ if target_id in [id_to_item .index (o ) for o in skip_items ]:
157+ target_item = "nothing"
158+
117159 obs = "You face {} at your front." .format (target_item )
118160 else :
119161 obs = "You face nothing at your front."
120162
121- for idx in np .unique (semantic ):
122- if idx == player_idx :
123- continue
124-
125- smallest = np .unravel_index (np .argmin (np .where (semantic == idx , dist , np .inf )), semantic .shape )
126- obj_info_list .append (
127- (
128- id_to_item [idx ],
129- dist [smallest ],
130- describe_loc (np .array ([0 , 0 ]), smallest - center ),
163+ # Edge detection
164+ edge_masks = {}
165+ for item_name in edge_only_items :
166+ item_idx = id_to_item .index (item_name )
167+ edge_masks [item_idx ] = get_edge_items (semantic , item_idx )
168+
169+ for i in range (semantic .shape [0 ]):
170+ for j in range (semantic .shape [1 ]):
171+ idx = semantic [i , j ]
172+ if idx == player_idx :
173+ continue
174+
175+ # only display the edge of items that are in edge_only_items
176+ if idx in edge_masks and not edge_masks [idx ][i , j ]:
177+ continue
178+
179+ # skip grass, sand or path so obs is not too long
180+ if idx in [id_to_item .index (o ) for o in skip_items ]:
181+ continue
182+
183+ obj_info_list .append ((id_to_item [idx ], describe_loc (np .array ([0 , 0 ]), np .array ([i , j ]) - center )))
184+
185+ def extract_numbers (s ):
186+ """Extract all numbers from a string."""
187+ return [int (num ) for num in re .findall (r"\d+" , s )]
188+
189+ # filter out items, so we only display closest item of each type
190+ if unique_items :
191+ closest_obj_info_list = defaultdict (str )
192+ for item_name , loc in obj_info_list :
193+ loc_dist = sum (extract_numbers (loc ))
194+ current_dist = (
195+ sum (extract_numbers (closest_obj_info_list [item_name ]))
196+ if closest_obj_info_list [item_name ]
197+ else float ("inf" )
131198 )
132- )
199+
200+ if loc_dist < current_dist :
201+ closest_obj_info_list [item_name ] = loc
202+ obj_info_list = [(name , loc ) for name , loc in closest_obj_info_list .items ()]
133203
134204 if len (obj_info_list ) > 0 :
135- status_str = "You see:\n {}" .format (
136- "\n " .join (["- {} {} steps to your {}" .format (name , dist , loc ) for name , dist , loc in obj_info_list ])
137- )
205+ status_str = "You see:\n {}" .format ("\n " .join (["- {} {}" .format (name , loc ) for name , loc in obj_info_list ]))
138206 else :
139207 status_str = "You see nothing away from you."
140208 result += status_str + "\n \n "
@@ -167,19 +235,30 @@ def describe_status(info):
167235 return ""
168236
169237
def describe_frame(
    info,
    unique_items=True,
    precise_location=False,
    skip_items=None,
    edge_only_items=None,
):
    """
    Build the text description of a frame together with the inventory text.

    The description is the output of ``describe_status(info)`` and of
    ``describe_env(info, ...)`` separated by a blank line, with surrounding
    whitespace stripped.

    Args:
        info: Environment info object, passed unchanged to the describe_* helpers.
        unique_items: Forwarded to ``describe_env``.
        precise_location: Forwarded to ``describe_env``.
        skip_items: Forwarded to ``describe_env``; ``None`` means an empty list.
        edge_only_items: Forwarded to ``describe_env``; ``None`` means an
            empty list.

    Returns:
        A ``(description, inventory)`` tuple. If building the description
        raises any ``Exception``, the description is the fixed string
        "Error, you are out of the map." and the inventory is still computed.
    """
    # None sentinels replace the mutable list defaults so no list object is
    # shared between calls.
    skip_items = [] if skip_items is None else skip_items
    edge_only_items = [] if edge_only_items is None else edge_only_items

    try:
        status_text = describe_status(info)
        env_text = describe_env(
            info,
            unique_items=unique_items,
            precise_location=precise_location,
            skip_items=skip_items,
            edge_only_items=edge_only_items,
        )
        description = "\n\n".join((status_text, env_text)).strip()
        return description, describe_inventory(info)
    except Exception:
        # Deliberate best-effort: any failure is reported to the caller as
        # text. NOTE(review): the message presumes the failure is the player
        # view leaving the map; other errors are masked by it as well.
        return "Error, you are out of the map.", describe_inventory(info)
183262
184263
185264class CrafterLanguageWrapper (gym .Wrapper ):
@@ -191,6 +270,10 @@ def __init__(
191270 env ,
192271 task = "" ,
193272 max_episode_steps = 2 ,
273+ unique_items = True ,
274+ precise_location = False ,
275+ skip_items = [],
276+ edge_only_items = [],
194277 ):
195278 super ().__init__ (env )
196279 self .score_tracker = 0
@@ -199,6 +282,11 @@ def __init__(
199282 self .max_steps = max_episode_steps
200283 self .achievements = None
201284
285+ self .unique_items = unique_items
286+ self .precise_location = precise_location
287+ self .skip_items = skip_items
288+ self .edge_only_items = edge_only_items
289+
202290 def get_text_action (self , action ):
203291 return self .language_action_space ._values [action ]
204292
@@ -232,7 +320,13 @@ def step(self, action):
232320
233321 def process_obs (self , obs , info ):
234322 img = Image .fromarray (self .env .render ()).convert ("RGB" )
235- long_term_context , short_term_context = describe_frame (info )
323+ long_term_context , short_term_context = describe_frame (
324+ info ,
325+ unique_items = self .unique_items ,
326+ precise_location = self .precise_location ,
327+ skip_items = self .skip_items ,
328+ edge_only_items = self .edge_only_items ,
329+ )
236330
237331 return {
238332 "text" : {
0 commit comments