Skip to content

Commit f22795d

Browse files
committed
nle language wrapper
1 parent 2046eee commit f22795d

File tree

20 files changed

+2836
-0
lines changed

20 files changed

+2836
-0
lines changed

CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ set_target_properties(_pynethack PROPERTIES CXX_STANDARD 14)
158158
target_include_directories(_pynethack PUBLIC ${NLE_INC_GEN})
159159
add_dependencies(_pynethack util) # For pm.h.
160160

161+
# pybind11 extension exposing NLE's language-observation wrapper to Python.
# Compiles the wrapper entry point together with the NetHack data tables it
# reads (monsters, declarations, drawing info, objects).
pybind11_add_module(
  nle_language_obsv
  win/rl/language_wrapper/main.cpp
  src/monst.c
  src/decl.c
  src/drawing.c
  src/objects.c
)
target_link_libraries(nle_language_obsv PUBLIC nethackdl)
set_target_properties(nle_language_obsv PROPERTIES CXX_STANDARD 14)
target_include_directories(nle_language_obsv PUBLIC ${NLE_INC_GEN})
add_dependencies(nle_language_obsv util) # For pm.h.
173+
161174
# ttyrec converter library
162175
add_library(
163176
converter STATIC ${CMAKE_CURRENT_SOURCE_DIR}/third_party/converter/converter.c

nle/language_wrapper/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from nle.language_wrapper.wrappers.nle_language_wrapper import NLELanguageWrapper

nle/language_wrapper/agents/__init__.py

Whitespace-only changes.

nle/language_wrapper/agents/sample_factory/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from sample_factory.algorithms.utils.arguments import arg_parser
2+
from sample_factory.algorithms.utils.arguments import parse_args
3+
4+
5+
def custom_parse_args(argv=None, evaluation=False):
    """Build sample factory's argument parser, register the wrapper-specific
    options, and return the fully parsed config.

    Args:
        argv: optional argument list; falls back to sys.argv inside
            sample factory when None.
        evaluation: forwarded to sample factory's parser/config builders.

    Returns:
        The config object produced by sample factory's ``parse_args``.
    """
    parser = arg_parser(argv, evaluation=evaluation)

    # Environment selection for the language wrapper.
    parser.add_argument(
        "--nle_env_name", type=str, default="NetHackChallenge-v0", help=""
    )

    # Integer hyperparameters added by this wrapper: (flag, default, help).
    int_options = (
        ("--transformer_hidden_size", 64, "size of transformer hidden layers"),
        ("--transformer_hidden_layers", 2, "number of transformer hidden layers"),
        (
            "--transformer_attention_heads",
            2,
            "number of transformer attention heads",
        ),
        (
            "--max_token_length",
            256,
            "Maximum token input length before truncation",
        ),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    return parse_args(argv=argv, evaluation=evaluation, parser=parser)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import sys
2+
3+
from sample_factory.algorithms.appo.enjoy_appo import enjoy
4+
5+
import nle.language_wrapper.agents.sample_factory.env # pylint: disable=['unused-import']
6+
import nle.language_wrapper.agents.sample_factory.language_encoder # pylint: disable=['unused-import']
7+
from nle.language_wrapper.agents.sample_factory.common import custom_parse_args
8+
9+
10+
def main():
    """Run sample factory's evaluation loop with the wrapper's custom config."""
    cfg = custom_parse_args(evaluation=True)
    return enjoy(cfg)


if __name__ == "__main__":
    sys.exit(main())
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from functools import lru_cache
2+
3+
import gymnasium as gym
4+
import numpy as np
5+
import torch
6+
from sample_factory.envs.env_registry import global_env_registry
7+
from transformers import RobertaTokenizerFast
8+
9+
from nle.language_wrapper import NLELanguageWrapper
10+
11+
12+
class SampleFactoryNLELanguageEnv(gym.Env):
    """Gym env adapting NLE language observations for sample factory.

    Wraps an NLE environment in ``NLELanguageWrapper`` and tokenizes the
    resulting text observation with a Roberta tokenizer so the policy
    consumes ``input_ids``/``attention_mask`` tensors of fixed length
    ``cfg["max_token_length"]``.
    """

    # Upper bound for the per-instance tokenization memo built in __init__.
    LRU_CACHE_SIZE = 1000

    def __init__(self, cfg):
        self.cfg = cfg
        max_tokens = self.cfg["max_token_length"]
        self.observation_space = gym.spaces.Dict()
        # Dummy "obs" entry: sample factory insists on normalizing this key.
        self.observation_space.spaces["obs"] = gym.spaces.Box(
            0, 1000000, shape=(1,), dtype=np.int32
        )
        self.observation_space.spaces["input_ids"] = gym.spaces.Box(
            0, 1000000, shape=(max_tokens,), dtype=np.int32
        )
        self.observation_space.spaces["attention_mask"] = gym.spaces.Box(
            0, 1, shape=(max_tokens,), dtype=np.int32
        )
        self.nle_env = gym.make(self.cfg["nle_env_name"])
        self.env = NLELanguageWrapper(self.nle_env, use_language_action=False)
        self.action_space = self.env.action_space
        self.tokenizer = RobertaTokenizerFast.from_pretrained(
            "distilroberta-base", truncation_side="left"
        )
        # Cache tokenization per instance to avoid re-tokenizing observations
        # already seen. A class-level @lru_cache on the method would key on
        # `self` and pin every env instance in memory for the cache's
        # lifetime (flake8-bugbear B019), so the cache is built here instead.
        self._tokenize = lru_cache(maxsize=self.LRU_CACHE_SIZE)(
            self._tokenize_uncached
        )

    def _tokenize_uncached(self, str_obsv):
        """Tokenize one observation string into fixed-length tensors."""
        tokens = self.tokenizer(
            str_obsv,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.cfg["max_token_length"],
        )
        # Sample factory insists on normalizing the "obs" key.
        tokens.data["obs"] = torch.zeros(1)
        return tokens.data

    def _convert_obsv_to_str(self, obsv):
        """Assemble the text observation fed to the tokenizer."""
        text_obsv = ""
        text_obsv += f"Inventory:\n{obsv['text_inventory']}\n\n"
        text_obsv += f"Stats:\n{obsv['text_blstats']}\n\n"
        text_obsv += f"Cursor:\n{obsv['text_cursor']}\n\n"
        # This section originally repeated the "Stats:" heading — an apparent
        # copy-paste slip, since it renders text_glyphs, not text_blstats.
        text_obsv += f"Glyphs:\n{obsv['text_glyphs']}\n\n"
        text_obsv += f"Message:\n{obsv['text_message']}"
        return text_obsv

    def reset(self, *, seed=None, **kwargs):
        super().reset(seed=seed)
        obsv, info = self.env.reset(**kwargs)
        return self._tokenize(self._convert_obsv_to_str(obsv)), info

    def step(self, action):
        obsv, reward, term, trun, info = self.env.step(action)
        tokenized_obsv = self._tokenize(self._convert_obsv_to_str(obsv))
        return tokenized_obsv, reward, term, trun, info

    def seed(self, *args):  # pylint: disable=['unused-argument']
        # NetHack does not allow seeding.
        return

    def render(self, *args, **kwargs):  # pylint: disable=['unused-argument']
        self.env.render()
74+
75+
def make_custom_env_func(
    full_env_name, cfg=None, env_config=None
):  # pylint: disable=['unused-argument']
    """Factory with the signature sample factory's env registry expects."""
    return SampleFactoryNLELanguageEnv(cfg)


# Register under a prefix so env names like "nle_language_env_*" resolve here.
global_env_registry().register_env(
    env_name_prefix="nle_language_env",
    make_env_func=make_custom_env_func,
)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
from sample_factory.algorithms.appo.model_utils import EncoderBase
3+
from sample_factory.algorithms.appo.model_utils import register_custom_encoder
4+
from transformers import RobertaConfig
5+
from transformers import RobertaModel
6+
7+
8+
class NLELanguageTransformerEncoder(EncoderBase):
    """Sample-factory encoder that embeds tokenized text with a small Roberta.

    The model is built from scratch (no pretrained weights); its size is
    controlled by the ``transformer_*`` config options.
    """

    def __init__(self, cfg, obs_space, timing):
        super().__init__(cfg, timing)
        # Roberta requires max sequence length + 2 position embeddings.
        max_positions = obs_space.spaces["input_ids"].shape[0] + 2
        config = RobertaConfig(
            vocab_size=50265,
            hidden_size=cfg.transformer_hidden_size,
            intermediate_size=cfg.transformer_hidden_size,
            num_hidden_layers=cfg.transformer_hidden_layers,
            num_attention_heads=cfg.transformer_attention_heads,
            max_position_embeddings=max_positions,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            classifier_dropout=None,
            initializer_range=0.02,
            layer_norm_eps=1e-05,
            bos_token_id=0,
            pad_token_id=1,
            eos_token_id=2,
            position_embedding_type="absolute",
            model_type="roberta",
            transformers_version="4.17.0",
            type_vocab_size=1,
            use_cache=False,
        )
        self.model = RobertaModel(config=config)
        self.encoder_out_size = self.model.config.hidden_size

    def device_and_type_for_input_tensor(
        self, input_tensor_name
    ):  # pylint: disable=['unused-argument']
        # NOTE(review): device is hard-coded, so this encoder will not run on
        # CPU-only hosts — consider deriving it from the config instead.
        return "cuda", torch.int32  # pylint: disable=['no-member']

    def forward(self, obs_dict):
        input_ids = obs_dict["input_ids"]
        attention_mask = obs_dict["attention_mask"]
        # Input transformation to allow for sample factory enjoy, which adds
        # a leading axis; drop it so the model sees (batch, seq).
        if len(input_ids.shape) == 3:
            input_ids = input_ids.squeeze(0)
            attention_mask = attention_mask.squeeze(0)
        # Rollout buffers may deliver float32; embeddings need int indices.
        if input_ids.dtype == torch.float32:  # pylint: disable=['no-member']
            input_ids = input_ids.long()
            attention_mask = attention_mask.long()
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first-token (CLS position) hidden state as the encoding.
        return output.last_hidden_state[:, 0]


register_custom_encoder(
    "nle_language_transformer_encoder", NLELanguageTransformerEncoder
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import sys
2+
3+
from sample_factory.algorithms.utils.arguments import arg_parser
4+
from sample_factory.algorithms.utils.arguments import parse_args
5+
from sample_factory.run_algorithm import run_algorithm
6+
7+
# Needs to be imported to register models and envs
8+
import nle.language_wrapper.agents.sample_factory.env # pylint: disable=['unused-import']
9+
import nle.language_wrapper.agents.sample_factory.language_encoder # pylint: disable=['unused-import']
10+
from nle.language_wrapper.agents.sample_factory.common import custom_parse_args
11+
12+
13+
def parse_all_args(argv=None, evaluation=False):
    """Parse args with sample factory's stock parser only (no custom flags)."""
    stock_parser = arg_parser(argv, evaluation=evaluation)
    return parse_args(argv=argv, evaluation=evaluation, parser=stock_parser)
17+
18+
19+
def main():
    """Launch sample factory training with the wrapper's custom arguments."""
    cfg = custom_parse_args()
    return run_algorithm(cfg)


if __name__ == "__main__":
    sys.exit(main())

nle/language_wrapper/scripts/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)