Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ jobs:
run: "pip install nle"
- name: Check nethack is installed
run: |
python -c 'import nle; import gymnasium as gym; e = gym.make("NetHack-v0"); e.reset(); e.step(0)'
python -c 'import nle; import gym; e = gym.make("NetHack-v0"); e.reset(); e.step(0)'

13 changes: 13 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@ set_target_properties(_pynethack PROPERTIES CXX_STANDARD 14)
target_include_directories(_pynethack PUBLIC ${NLE_INC_GEN})
add_dependencies(_pynethack util) # For pm.h.

# Python extension module exposing the NLE language-observation generator.
# Mirrors the _pynethack target above: same C sources, same include dirs,
# same C++14 requirement, and the same dependency on `util` for pm.h.
pybind11_add_module(
nle_language_obsv
win/rl/language_wrapper/main.cpp
src/monst.c
src/decl.c
src/drawing.c
src/objects.c
)
target_link_libraries(nle_language_obsv PUBLIC nethackdl)
set_target_properties(nle_language_obsv PROPERTIES CXX_STANDARD 14)
target_include_directories(nle_language_obsv PUBLIC ${NLE_INC_GEN})
add_dependencies(nle_language_obsv util) # For pm.h.

# ttyrec converter library
add_library(
converter STATIC ${CMAKE_CURRENT_SOURCE_DIR}/third_party/converter/converter.c
Expand Down
1 change: 1 addition & 0 deletions nle/language_wrapper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from nle.language_wrapper.wrappers.nle_language_wrapper import NLELanguageWrapper
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions nle/language_wrapper/agents/sample_factory/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from sample_factory.algorithms.utils.arguments import arg_parser
from sample_factory.algorithms.utils.arguments import parse_args


def custom_parse_args(argv=None, evaluation=False):
    """Build sample-factory's argument parser, add the language-wrapper
    specific options, and return the parsed config."""
    parser = arg_parser(argv, evaluation=evaluation)

    # add custom args here
    parser.add_argument(
        "--nle_env_name", type=str, default="NetHackChallenge-v0", help=""
    )
    # Transformer/tokenizer hyper-parameters layered on top of the defaults.
    extra_int_options = (
        ("--transformer_hidden_size", 64, "size of transformer hidden layers"),
        ("--transformer_hidden_layers", 2, "number of transformer hidden layers"),
        (
            "--transformer_attention_heads",
            2,
            "number of transformer attention heads",
        ),
        (
            "--max_token_length",
            256,
            "Maximum token input length before truncation",
        ),
    )
    for flag, default, help_text in extra_int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    return parse_args(argv=argv, evaluation=evaluation, parser=parser)
17 changes: 17 additions & 0 deletions nle/language_wrapper/agents/sample_factory/enjoy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys

from sample_factory.algorithms.appo.enjoy_appo import enjoy

import nle.language_wrapper.agents.sample_factory.env # pylint: disable=['unused-import']
import nle.language_wrapper.agents.sample_factory.language_encoder # pylint: disable=['unused-import']
from nle.language_wrapper.agents.sample_factory.common import custom_parse_args


def main():
    """Run sample-factory's evaluation loop on the parsed config.

    Returns the status code produced by `enjoy`, suitable for sys.exit.
    """
    return enjoy(custom_parse_args(evaluation=True))


if __name__ == "__main__":
    sys.exit(main())
83 changes: 83 additions & 0 deletions nle/language_wrapper/agents/sample_factory/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from functools import lru_cache

import gym
import numpy as np
import torch
from sample_factory.envs.env_registry import global_env_registry
from transformers import RobertaTokenizerFast

from nle.language_wrapper import NLELanguageWrapper


class SampleFactoryNLELanguageEnv(gym.Env):
    """Gym env that wraps an NLE environment with the language wrapper and
    emits Roberta-tokenized observations in the dict layout sample-factory
    expects ("obs" / "input_ids" / "attention_mask").
    """

    # Upper bound on memoized tokenizations per instance.
    LRU_CACHE_SIZE = 1000

    def __init__(self, cfg):
        """cfg: mapping with at least "max_token_length" and "nle_env_name"."""
        self.cfg = cfg
        self.observation_space = gym.spaces.Dict()
        # "obs" is a dummy entry; sample factory insists on an "obs" key.
        self.observation_space.spaces["obs"] = gym.spaces.Box(
            0, 1000000, shape=(1,), dtype=np.int32
        )
        self.observation_space.spaces["input_ids"] = gym.spaces.Box(
            0, 1000000, shape=(self.cfg["max_token_length"],), dtype=np.int32
        )
        self.observation_space.spaces["attention_mask"] = gym.spaces.Box(
            0, 1, shape=(self.cfg["max_token_length"],), dtype=np.int32
        )
        self.nle_env = gym.make(self.cfg["nle_env_name"])
        self.env = NLELanguageWrapper(self.nle_env, use_language_action=False)
        self.action_space = self.env.action_space
        self.tokenizer = RobertaTokenizerFast.from_pretrained(
            "distilroberta-base", truncation_side="left"
        )
        # Per-instance memoization of tokenized observations. Decorating the
        # method with @lru_cache directly would key the cache on `self` and
        # keep the whole env/tokenizer alive for the cache's lifetime (B019).
        self._tokenize = lru_cache(maxsize=self.LRU_CACHE_SIZE)(
            self._tokenize_uncached
        )

    def _tokenize_uncached(self, str_obsv):
        """Tokenize one observation string; wrapped by lru_cache in __init__."""
        tokens = self.tokenizer(
            str_obsv,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.cfg["max_token_length"],
        )
        # Sample factory insists on normalizing obs key.
        tokens.data["obs"] = torch.zeros(1)
        return tokens.data

    def _convert_obsv_to_str(self, obsv):
        """Concatenate the text observation fields into one prompt string."""
        text_obsv = ""
        text_obsv += f"Inventory:\n{obsv['text_inventory']}\n\n"
        text_obsv += f"Stats:\n{obsv['text_blstats']}\n\n"
        text_obsv += f"Cursor:\n{obsv['text_cursor']}\n\n"
        # Was mislabeled "Stats:" (copy-paste of the blstats label); the
        # glyphs field is the map observation, labeled "Observation:" in
        # scripts/play.py.
        text_obsv += f"Observation:\n{obsv['text_glyphs']}\n\n"
        text_obsv += f"Message:\n{obsv['text_message']}"
        return text_obsv

    def reset(self, **kwargs):
        """Reset the wrapped env and return the tokenized first observation."""
        return self._tokenize(self._convert_obsv_to_str(self.env.reset(**kwargs)))

    def step(self, action):
        """Step the wrapped env; returns (tokenized_obsv, reward, done, info)."""
        obsv, reward, done, info = self.env.step(action)
        tokenized_obsv = self._tokenize(self._convert_obsv_to_str(obsv))
        return tokenized_obsv, reward, done, info

    def seed(self, *args):  # pylint: disable=['unused-argument']
        # Nethack does not allow seeding
        return

    def render(self, *args, **kwargs):  # pylint: disable=['unused-argument']
        self.env.render()


def make_custom_env_func(
    full_env_name, cfg=None, env_config=None
):  # pylint: disable=['unused-argument']
    """Env factory invoked by sample-factory's registry."""
    return SampleFactoryNLELanguageEnv(cfg)


# Make every "nle_language_env*" env name resolve to the factory above.
global_env_registry().register_env(
    env_name_prefix="nle_language_env",
    make_env_func=make_custom_env_func,
)
58 changes: 58 additions & 0 deletions nle/language_wrapper/agents/sample_factory/language_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch
from sample_factory.algorithms.appo.model_utils import EncoderBase
from sample_factory.algorithms.appo.model_utils import register_custom_encoder
from transformers import RobertaConfig
from transformers import RobertaModel


class NLELanguageTransformerEncoder(EncoderBase):
    """Sample-factory encoder: a small, randomly initialized Roberta run over
    the tokenized language observation; the embedding of the first token is
    the policy's encoding (encoder_out_size == hidden_size)."""

    def __init__(self, cfg, obs_space, timing):
        super().__init__(cfg, timing)
        # Config built from scratch (no pretrained weights). Dropout is
        # disabled; width/depth/heads come from the CLI config.
        config = RobertaConfig(
            attention_probs_dropout_prob=0.0,
            bos_token_id=0,
            classifier_dropout=None,
            eos_token_id=2,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            hidden_size=cfg.transformer_hidden_size,
            initializer_range=0.02,
            intermediate_size=cfg.transformer_hidden_size,
            layer_norm_eps=1e-05,
            max_position_embeddings=obs_space.spaces["input_ids"].shape[0]
            + 2,  # Roberta requires max sequence length + 2.
            model_type="roberta",
            num_attention_heads=cfg.transformer_attention_heads,
            num_hidden_layers=cfg.transformer_hidden_layers,
            pad_token_id=1,
            position_embedding_type="absolute",
            transformers_version="4.17.0",
            type_vocab_size=1,
            use_cache=False,
            vocab_size=50265,  # matches the distilroberta-base tokenizer
        )
        self.model = RobertaModel(config=config)
        self.encoder_out_size = self.model.config.hidden_size

    def device_and_type_for_input_tensor(
        self, input_tensor_name
    ):  # pylint: disable=['unused-argument']
        # NOTE(review): device is hard-coded to "cuda"; CPU-only runs would
        # need this made configurable — confirm with the training setup.
        return "cuda", torch.int32  # pylint: disable=['no-member']

    def forward(self, obs_dict):
        """Encode tokenized observations; returns (batch, hidden_size)."""
        input_ids = obs_dict["input_ids"]
        attention_mask = obs_dict["attention_mask"]
        # Input transformation to allow for sample factory enjoy
        if len(input_ids.shape) == 3:
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        if input_ids.dtype == torch.float32:  # pylint: disable=['no-member']
            input_ids = input_ids.long()
            attention_mask = attention_mask.long()
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # First-position embedding (the <s> token) as the sequence summary.
        return output.last_hidden_state[:, 0]


# Make the encoder selectable via sample-factory's custom-encoder mechanism.
register_custom_encoder("nle_language_transformer_encoder", NLELanguageTransformerEncoder)
26 changes: 26 additions & 0 deletions nle/language_wrapper/agents/sample_factory/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys

from sample_factory.algorithms.utils.arguments import arg_parser
from sample_factory.algorithms.utils.arguments import parse_args
from sample_factory.run_algorithm import run_algorithm

# Needs to be imported to register models and envs
import nle.language_wrapper.agents.sample_factory.env # pylint: disable=['unused-import']
import nle.language_wrapper.agents.sample_factory.language_encoder # pylint: disable=['unused-import']
from nle.language_wrapper.agents.sample_factory.common import custom_parse_args


def parse_all_args(argv=None, evaluation=False):
    """Parse only sample-factory's built-in arguments (no custom options).

    NOTE(review): not referenced by main() below, which uses
    custom_parse_args — confirm whether this is still needed.
    """
    parser = arg_parser(argv, evaluation=evaluation)
    return parse_args(argv=argv, evaluation=evaluation, parser=parser)


def main():
    """Train with sample-factory on the custom config; return exit status."""
    return run_algorithm(custom_parse_args())


if __name__ == "__main__":
    sys.exit(main())
Empty file.
125 changes: 125 additions & 0 deletions nle/language_wrapper/scripts/play.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import gym
import minihack # pylint: disable=unused-import
import nle # pylint: disable=unused-import
from minihack.scripts.env_list import skip_envs_list

from nle.language_wrapper import NLELanguageWrapper


def _format_observation(obsv):
    """Render the language observation dict as the text shown to the player."""
    output = ""
    output += f"Inventory:\n{obsv['text_inventory']}\n\n"
    output += f"Stats:\n{obsv['text_blstats']}\n\n"
    output += f"Cursor:{obsv['text_cursor']}\n\n"
    output += f"Observation:\n{obsv['text_glyphs']}\n\n"
    output += f"Message:\n{obsv['text_message']}"
    if output[-1] != "\n":
        output += "\n"
    return output


def _print_help_banner():
    """One-time banner pointing at the meta commands."""
    print(
        'Type "instructions" anytime to see '
        "the instructions on how to play"
    )
    print(
        'Type "actions" anytime to see valid actions '
        "for the current environment"
    )
    print('Type "render" anytime to render the nle environment')


def _print_actions(env):
    """List every text action and the NLE action enum it maps to."""
    print("")
    print("Actions")
    max_key_len = max(
        len(repr(k)) for k, _ in env.action_str_enum_map.items()
    )
    for action_str, nle_action_enum in env.action_str_enum_map.items():
        key_str = f"{repr(action_str)}"
        # Pad so the ":" column lines up across all actions.
        print(
            key_str
            + (max_key_len - len(key_str)) * " "
            + f":{str(nle_action_enum)}"
        )
    print("")


def _print_instructions():
    """Explain how to enter actions at the prompt."""
    print("")
    print("Instructions")
    print(
        "To play, write a text action at the prompt, "
        'e.g "north" will move you north.'
    )
    print(
        "For the complete list of actions supported "
        'in the current environment type "actions"'
    )
    print('To render the current display type "render"')
    print("")


def main(nethack_env_name):
    """
    Play a NLE based environment using the nle-language-wrapper.

    Loops reading text actions from stdin until the episode ends;
    returns the total reward accumulated over the episode.
    """
    env = NLELanguageWrapper(
        gym.make(
            nethack_env_name,
            observation_keys=[
                "glyphs",
                "blstats",
                "tty_chars",
                "inv_letters",
                "inv_strs",
                "tty_cursor",
                "tty_colors",
            ],
        )
    )
    obsv = env.reset()
    total_reward = 0.0
    shown_help = False
    done = False

    while True:
        print(_format_observation(obsv))
        print("------")
        if done:
            return total_reward
        valid_action = False
        while not valid_action:
            if not shown_help:
                _print_help_banner()
                shown_help = True
            action = input("Action: ")
            # Meta commands that don't step the environment.
            if action == "actions":
                _print_actions(env)
                continue
            if action == "instructions":
                _print_instructions()
                continue
            if action == "render":
                env.render()
                continue
            # Invalid text actions raise ValueError; re-prompt instead of
            # crashing.
            try:
                obsv, reward, done, _ = env.step(action)
                total_reward += reward
                valid_action = True
            except ValueError as exception:
                print(exception)


if __name__ == "__main__":
    while True:
        # Offer every MiniHack env (minus the known-broken ones) plus the
        # NetHack challenge env.
        nle_env_names = [
            spec.id
            for spec in gym.envs.registry.all()
            if "MiniHack" in spec.id and spec.id not in skip_envs_list
        ]
        nle_env_names.append("NetHackChallenge-v0")
        for i, env_name in enumerate(nle_env_names):
            print(f"{i}: {env_name}")
        prompt = (
            f"Which base env would you like to use [0-{len(nle_env_names) - 1}]? "
        )
        nle_env_name = nle_env_names[int(input(prompt))]
        total_reward = main(nle_env_name)
        print("Done!")
        print(f"Total Reward: {total_reward}")
        if input("Play again? [y,n]:") != "y":
            break
Empty file.
Loading
Loading