Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion balrog/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ envs:
size: [256, 256] # Image size in Crafter
reward: True
seed: null
max_episode_steps: 2000
max_episode_steps: 2000
unique_items: True # False
precise_location: False # True
skip_items: [] # ["grass", "sand", "path"]
edge_only_items: [] # ["water", "lava"]
textworld_kwargs:
objective: True
description: True
Expand Down
15 changes: 13 additions & 2 deletions balrog/environments/crafter/crafter_env.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
from typing import Optional

import crafter

from balrog.environments.crafter import CrafterLanguageWrapper
from balrog.environments.wrappers import GymV21CompatibilityV0


def make_crafter_env(env_name, task, config, render_mode: Optional[str] = None):
crafter_kwargs = dict(config.envs.crafter_kwargs)
max_episode_steps = crafter_kwargs.pop("max_episode_steps", 2)
unique_items = crafter_kwargs.pop("unique_items", True)
precise_location = crafter_kwargs.pop("precise_location", False)
skip_items = crafter_kwargs.pop("skip_items", [])
edge_only_items = crafter_kwargs.pop("edge_only_items", [])

for param in ["area", "view", "size"]:
if param in crafter_kwargs:
crafter_kwargs[param] = tuple(crafter_kwargs[param])

env = crafter.Env(**crafter_kwargs)
env = CrafterLanguageWrapper(env, task, max_episode_steps=max_episode_steps)
env = CrafterLanguageWrapper(
env,
task,
max_episode_steps=max_episode_steps,
unique_items=unique_items,
precise_location=precise_location,
skip_items=skip_items,
edge_only_items=edge_only_items,
)
env = GymV21CompatibilityV0(env=env, render_mode=render_mode)

return env
150 changes: 122 additions & 28 deletions balrog/environments/crafter/env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import itertools
import re
from collections import defaultdict

import crafter
import gym
import numpy as np
from PIL import Image
from scipy import ndimage

from balrog.environments import Strings

Expand Down Expand Up @@ -77,7 +80,29 @@ def rotation_matrix(v1, v2):
return rotation_matrix


def describe_loc(ref, P):
def describe_loc_precise(ref, P):
"""
Describe the location of P relative to ref.
Example: `1 step south and 4 steps west`
"""
desc = []

def distange_to_string(distance, direction):
return f"{abs(distance)} step{'s' if abs(distance) > 1 else ''} {direction}"

if ref[1] > P[1]:
desc.append(distange_to_string(ref[1] - P[1], "north"))
elif ref[1] < P[1]:
desc.append(distange_to_string(ref[1] - P[1], "south"))
if ref[0] > P[0]:
desc.append(distange_to_string(ref[0] - P[0], "west"))
elif ref[0] < P[0]:
desc.append(distange_to_string(ref[0] - P[0], "east"))

return " and ".join(desc) if desc else "at your location"


def describe_loc_old(ref, P):
desc = []
if ref[1] > P[1]:
desc.append("north")
Expand All @@ -88,53 +113,96 @@ def describe_loc(ref, P):
elif ref[0] < P[0]:
desc.append("east")

return "-".join(desc)
distance = abs(ref[1] - P[1]) + abs(ref[0] - P[0])
distance_str = f"{distance} step{'s' if distance > 1 else ''} to your {'-'.join(desc)}"

return distance_str


def describe_env(info):
def get_edge_items(semantic, item_idx):
item_mask = semantic == item_idx
not_item_mask = semantic != item_idx
item_edge = ndimage.binary_dilation(not_item_mask) & item_mask
return item_edge


def describe_env(
info,
unique_items=True,
precise_location=False,
skip_items=[],
edge_only_items=[],
):
assert info["semantic"][info["player_pos"][0], info["player_pos"][1]] == player_idx
semantic = info["semantic"][
info["player_pos"][0] - info["view"][0] // 2 : info["player_pos"][0] + info["view"][0] // 2 + 1,
info["player_pos"][1] - info["view"][1] // 2 + 1 : info["player_pos"][1] + info["view"][1] // 2,
]
center = np.array([info["view"][0] // 2, info["view"][1] // 2 - 1])
result = ""
x = np.arange(semantic.shape[1])
y = np.arange(semantic.shape[0])
x1, y1 = np.meshgrid(x, y)
loc = np.stack((y1, x1), axis=-1)
dist = np.absolute(center - loc).sum(axis=-1)
describe_loc = describe_loc_precise if precise_location else describe_loc_old
obj_info_list = []

facing = info["player_facing"]
max_y, max_x = semantic.shape
max_x, max_y = semantic.shape
target_x = center[0] + facing[0]
target_y = center[1] + facing[1]

if 0 <= target_x < max_x and 0 <= target_y < max_y:
target_id = semantic[int(target_x), int(target_y)]
target_item = id_to_item[target_id]

# skip grass, sand or path so obs here, since we are not displaying them
if target_id in [id_to_item.index(o) for o in skip_items]:
target_item = "nothing"

obs = "You face {} at your front.".format(target_item)
else:
obs = "You face nothing at your front."

for idx in np.unique(semantic):
if idx == player_idx:
continue

smallest = np.unravel_index(np.argmin(np.where(semantic == idx, dist, np.inf)), semantic.shape)
obj_info_list.append(
(
id_to_item[idx],
dist[smallest],
describe_loc(np.array([0, 0]), smallest - center),
# Edge detection
edge_masks = {}
for item_name in edge_only_items:
item_idx = id_to_item.index(item_name)
edge_masks[item_idx] = get_edge_items(semantic, item_idx)

for i in range(semantic.shape[0]):
for j in range(semantic.shape[1]):
idx = semantic[i, j]
if idx == player_idx:
continue

# only display the edge of items that are in edge_only_items
if idx in edge_masks and not edge_masks[idx][i, j]:
continue

# skip grass, sand or path so obs is not too long
if idx in [id_to_item.index(o) for o in skip_items]:
continue

obj_info_list.append((id_to_item[idx], describe_loc(np.array([0, 0]), np.array([i, j]) - center)))

def extract_numbers(s):
"""Extract all numbers from a string."""
return [int(num) for num in re.findall(r"\d+", s)]

# filter out items, so we only display closest item of each type
if unique_items:
closest_obj_info_list = defaultdict(str)
for item_name, loc in obj_info_list:
loc_dist = sum(extract_numbers(loc))
current_dist = (
sum(extract_numbers(closest_obj_info_list[item_name]))
if closest_obj_info_list[item_name]
else float("inf")
)
)

if loc_dist < current_dist:
closest_obj_info_list[item_name] = loc
obj_info_list = [(name, loc) for name, loc in closest_obj_info_list.items()]

if len(obj_info_list) > 0:
status_str = "You see:\n{}".format(
"\n".join(["- {} {} steps to your {}".format(name, dist, loc) for name, dist, loc in obj_info_list])
)
status_str = "You see:\n{}".format("\n".join(["- {} {}".format(name, loc) for name, loc in obj_info_list]))
else:
status_str = "You see nothing away from you."
result += status_str + "\n\n"
Expand Down Expand Up @@ -167,19 +235,30 @@ def describe_status(info):
return ""


def describe_frame(info):
def describe_frame(
info,
unique_items=True,
precise_location=False,
skip_items=[],
edge_only_items=[],
):
try:
result = ""

result += describe_status(info)
result += "\n\n"
result += describe_env(info)
result += describe_env(
info,
unique_items=unique_items,
precise_location=precise_location,
skip_items=skip_items,
edge_only_items=edge_only_items,
)
result += "\n\n"

return result.strip(), describe_inventory(info)
except Exception:
breakpoint()
return "Error, you are out of the map."
return "Error, you are out of the map.", describe_inventory(info)


class CrafterLanguageWrapper(gym.Wrapper):
Expand All @@ -191,6 +270,10 @@ def __init__(
env,
task="",
max_episode_steps=2,
unique_items=True,
precise_location=False,
skip_items=[],
edge_only_items=[],
):
super().__init__(env)
self.score_tracker = 0
Expand All @@ -199,6 +282,11 @@ def __init__(
self.max_steps = max_episode_steps
self.achievements = None

self.unique_items = unique_items
self.precise_location = precise_location
self.skip_items = skip_items
self.edge_only_items = edge_only_items

def get_text_action(self, action):
return self.language_action_space._values[action]

Expand Down Expand Up @@ -232,7 +320,13 @@ def step(self, action):

def process_obs(self, obs, info):
img = Image.fromarray(self.env.render()).convert("RGB")
long_term_context, short_term_context = describe_frame(info)
long_term_context, short_term_context = describe_frame(
info,
unique_items=self.unique_items,
precise_location=self.precise_location,
skip_items=self.skip_items,
edge_only_items=self.edge_only_items,
)

return {
"text": {
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"opencv-python-headless",
"wandb",
"pytest",
"scipy",
"crafter",
"gym==0.23",
"requests",
Expand Down
Loading