diff --git a/balrog/config/config.yaml b/balrog/config/config.yaml index fa8dd367..89d5ed72 100644 --- a/balrog/config/config.yaml +++ b/balrog/config/config.yaml @@ -66,7 +66,11 @@ envs: size: [256, 256] # Image size in Crafter reward: True seed: null - max_episode_steps: 2000 + max_episode_steps: 2000 + unique_items: True # False + precise_location: False # True + skip_items: [] # ["grass", "sand", "path"] + edge_only_items: [] # ["water", "lava"] textworld_kwargs: objective: True description: True diff --git a/balrog/environments/crafter/crafter_env.py b/balrog/environments/crafter/crafter_env.py index c4aa5733..765b07ca 100644 --- a/balrog/environments/crafter/crafter_env.py +++ b/balrog/environments/crafter/crafter_env.py @@ -1,7 +1,6 @@ from typing import Optional import crafter - from balrog.environments.crafter import CrafterLanguageWrapper from balrog.environments.wrappers import GymV21CompatibilityV0 @@ -9,13 +8,25 @@ def make_crafter_env(env_name, task, config, render_mode: Optional[str] = None): crafter_kwargs = dict(config.envs.crafter_kwargs) max_episode_steps = crafter_kwargs.pop("max_episode_steps", 2) + unique_items = crafter_kwargs.pop("unique_items", True) + precise_location = crafter_kwargs.pop("precise_location", False) + skip_items = crafter_kwargs.pop("skip_items", []) + edge_only_items = crafter_kwargs.pop("edge_only_items", []) for param in ["area", "view", "size"]: if param in crafter_kwargs: crafter_kwargs[param] = tuple(crafter_kwargs[param]) env = crafter.Env(**crafter_kwargs) - env = CrafterLanguageWrapper(env, task, max_episode_steps=max_episode_steps) + env = CrafterLanguageWrapper( + env, + task, + max_episode_steps=max_episode_steps, + unique_items=unique_items, + precise_location=precise_location, + skip_items=skip_items, + edge_only_items=edge_only_items, + ) env = GymV21CompatibilityV0(env=env, render_mode=render_mode) return env diff --git a/balrog/environments/crafter/env.py b/balrog/environments/crafter/env.py index 2ca72847..f783dfc5 100644 --- a/balrog/environments/crafter/env.py +++ b/balrog/environments/crafter/env.py @@ -1,9 +1,12 @@ import itertools +import re +from collections import defaultdict import crafter import gym import numpy as np from PIL import Image +from scipy import ndimage from balrog.environments import Strings @@ -77,7 +80,29 @@ def rotation_matrix(v1, v2): return rotation_matrix -def describe_loc(ref, P): +def describe_loc_precise(ref, P): + """ + Describe the location of P relative to ref. + Example: `1 step south and 4 steps west` + """ + desc = [] + + def distange_to_string(distance, direction): + return f"{abs(distance)} step{'s' if abs(distance) > 1 else ''} {direction}" + + if ref[1] > P[1]: + desc.append(distange_to_string(ref[1] - P[1], "north")) + elif ref[1] < P[1]: + desc.append(distange_to_string(ref[1] - P[1], "south")) + if ref[0] > P[0]: + desc.append(distange_to_string(ref[0] - P[0], "west")) + elif ref[0] < P[0]: + desc.append(distange_to_string(ref[0] - P[0], "east")) + + return " and ".join(desc) if desc else "at your location" + + +def describe_loc_old(ref, P): desc = [] if ref[1] > P[1]: desc.append("north") @@ -88,10 +113,26 @@ def describe_loc(ref, P): elif ref[0] < P[0]: desc.append("east") - return "-".join(desc) + distance = abs(ref[1] - P[1]) + abs(ref[0] - P[0]) + distance_str = f"{distance} step{'s' if distance > 1 else ''} to your {'-'.join(desc)}" + + return distance_str -def describe_env(info): +def get_edge_items(semantic, item_idx): + item_mask = semantic == item_idx + not_item_mask = semantic != item_idx + item_edge = ndimage.binary_dilation(not_item_mask) & item_mask + return item_edge + + +def describe_env( + info, + unique_items=True, + precise_location=False, + skip_items=[], + edge_only_items=[], +): assert info["semantic"][info["player_pos"][0], info["player_pos"][1]] == player_idx semantic = info["semantic"][ info["player_pos"][0] - info["view"][0] // 2 : info["player_pos"][0] + info["view"][0] // 2 + 1, @@ -99,42 +140,69 @@ def describe_env(info): ] center = np.array([info["view"][0] // 2, info["view"][1] // 2 - 1]) result = "" - x = np.arange(semantic.shape[1]) - y = np.arange(semantic.shape[0]) - x1, y1 = np.meshgrid(x, y) - loc = np.stack((y1, x1), axis=-1) - dist = np.absolute(center - loc).sum(axis=-1) + describe_loc = describe_loc_precise if precise_location else describe_loc_old obj_info_list = [] facing = info["player_facing"] - max_y, max_x = semantic.shape + max_x, max_y = semantic.shape target_x = center[0] + facing[0] target_y = center[1] + facing[1] if 0 <= target_x < max_x and 0 <= target_y < max_y: target_id = semantic[int(target_x), int(target_y)] target_item = id_to_item[target_id] + + # skip grass, sand or path so obs here, since we are not displaying them + if target_id in [id_to_item.index(o) for o in skip_items]: + target_item = "nothing" + obs = "You face {} at your front.".format(target_item) else: obs = "You face nothing at your front." - for idx in np.unique(semantic): - if idx == player_idx: - continue - - smallest = np.unravel_index(np.argmin(np.where(semantic == idx, dist, np.inf)), semantic.shape) - obj_info_list.append( - ( - id_to_item[idx], - dist[smallest], - describe_loc(np.array([0, 0]), smallest - center), + # Edge detection + edge_masks = {} + for item_name in edge_only_items: + item_idx = id_to_item.index(item_name) + edge_masks[item_idx] = get_edge_items(semantic, item_idx) + + for i in range(semantic.shape[0]): + for j in range(semantic.shape[1]): + idx = semantic[i, j] + if idx == player_idx: + continue + + # only display the edge of items that are in edge_only_items + if idx in edge_masks and not edge_masks[idx][i, j]: + continue + + # skip grass, sand or path so obs is not too long + if idx in [id_to_item.index(o) for o in skip_items]: + continue + + obj_info_list.append((id_to_item[idx], describe_loc(np.array([0, 0]), np.array([i, j]) - center))) + + def extract_numbers(s): + """Extract all numbers from a string.""" + return [int(num) for num in re.findall(r"\d+", s)] + + # filter out items, so we only display closest item of each type + if unique_items: + closest_obj_info_list = defaultdict(str) + for item_name, loc in obj_info_list: + loc_dist = sum(extract_numbers(loc)) + current_dist = ( + sum(extract_numbers(closest_obj_info_list[item_name])) + if closest_obj_info_list[item_name] + else float("inf") ) - ) + + if loc_dist < current_dist: + closest_obj_info_list[item_name] = loc + obj_info_list = [(name, loc) for name, loc in closest_obj_info_list.items()] if len(obj_info_list) > 0: - status_str = "You see:\n{}".format( - "\n".join(["- {} {} steps to your {}".format(name, dist, loc) for name, dist, loc in obj_info_list]) - ) + status_str = "You see:\n{}".format("\n".join(["- {} {}".format(name, loc) for name, loc in obj_info_list])) else: status_str = "You see nothing away from you." result += status_str + "\n\n" @@ -167,19 +235,30 @@ def describe_status(info): return "" -def describe_frame(info): +def describe_frame( + info, + unique_items=True, + precise_location=False, + skip_items=[], + edge_only_items=[], +): try: result = "" result += describe_status(info) result += "\n\n" - result += describe_env(info) + result += describe_env( + info, + unique_items=unique_items, + precise_location=precise_location, + skip_items=skip_items, + edge_only_items=edge_only_items, + ) result += "\n\n" return result.strip(), describe_inventory(info) except Exception: - breakpoint() - return "Error, you are out of the map." + return "Error, you are out of the map.", describe_inventory(info) class CrafterLanguageWrapper(gym.Wrapper): @@ -191,6 +270,10 @@ def __init__( env, task="", max_episode_steps=2, + unique_items=True, + precise_location=False, + skip_items=[], + edge_only_items=[], ): super().__init__(env) self.score_tracker = 0 @@ -199,6 +282,11 @@ def __init__( self.max_steps = max_episode_steps self.achievements = None + self.unique_items = unique_items + self.precise_location = precise_location + self.skip_items = skip_items + self.edge_only_items = edge_only_items + def get_text_action(self, action): return self.language_action_space._values[action] @@ -232,7 +320,13 @@ def step(self, action): def process_obs(self, obs, info): img = Image.fromarray(self.env.render()).convert("RGB") - long_term_context, short_term_context = describe_frame(info) + long_term_context, short_term_context = describe_frame( + info, + unique_items=self.unique_items, + precise_location=self.precise_location, + skip_items=self.skip_items, + edge_only_items=self.edge_only_items, + ) return { "text": { diff --git a/setup.py b/setup.py index 5999890f..268f7a86 100644 --- a/setup.py +++ b/setup.py @@ -26,6 +26,7 @@ "opencv-python-headless", "wandb", "pytest", + "scipy", "crafter", "gym==0.23", "requests",