Skip to content

Commit 8851290

Browse files
committed
quick script for playing with nle
1 parent 0d07d3f commit 8851290

File tree

1 file changed

+118
-0
lines changed

1 file changed

+118
-0
lines changed

play.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import os
2+
import random
3+
import readline
4+
import timeit
5+
from datetime import datetime
6+
from functools import partial
7+
from pathlib import Path
8+
from pprint import pprint
9+
10+
import hydra
11+
import numpy as np
12+
from hydra.utils import get_original_cwd
13+
from omegaconf import DictConfig
14+
15+
from balrog.agents import AgentFactory
16+
from balrog.environments import make_env
17+
from balrog.evaluator import EvaluatorManager
18+
from balrog.utils import get_unique_seed, setup_environment
19+
20+
21+
def completer(text, state, commands=[]):
22+
options = [cmd for cmd in commands if cmd.startswith(text)]
23+
return options[state] if state < len(options) else None
24+
25+
26+
def setup_autocomplete(completer_fn):
27+
readline.parse_and_bind("tab: complete")
28+
print("Type commands and use TAB to autocomplete.")
29+
print("To see strategies use command: `help`")
30+
readline.set_completer(completer_fn)
31+
32+
33+
def get_action(env, obs):
34+
language_action_space = env.get_wrapper_attr("language_action_space")
35+
setup_autocomplete(partial(completer, commands=language_action_space))
36+
37+
while True:
38+
command = input("> ")
39+
40+
if command == "help":
41+
print(language_action_space)
42+
continue
43+
else:
44+
try:
45+
assert command in language_action_space
46+
break
47+
except Exception:
48+
print(f"Selected action '{command}' is not in action list. Please try again.")
49+
continue
50+
51+
return command
52+
53+
54+
@hydra.main(config_path="balrog/config", config_name="config", version_base="1.1")
55+
def main(config: DictConfig):
56+
original_cwd = get_original_cwd()
57+
setup_environment(original_cwd=original_cwd)
58+
59+
# Determine output directory
60+
if config.eval.resume_from is not None:
61+
output_dir = config.eval.resume_from
62+
else:
63+
now = datetime.now()
64+
timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
65+
run_name = f"{timestamp}_{config.agent.type}_{config.client.model_id.replace('/', '_')}"
66+
output_dir = os.path.join(config.eval.output_dir, run_name)
67+
68+
# Create the directory if it doesn't exist
69+
Path(output_dir).mkdir(parents=True, exist_ok=True)
70+
71+
env_name = random.choice(config.envs.names.split("-"))
72+
task = random.choice(config.tasks[f"{env_name}_tasks"])
73+
print(f"Selected environment: {env_name}, task: {task}")
74+
75+
env = make_env(env_name, task, config, render_mode="human")
76+
77+
seed = config.envs.env_kwargs.seed
78+
if seed is None:
79+
seed = get_unique_seed(process_num=None, episode_idx=0)
80+
random.seed(seed)
81+
np.random.seed(seed)
82+
obs, info = env.reset(seed=seed)
83+
env.render()
84+
85+
steps = 0
86+
reward = 0.0
87+
total_reward = 0.0
88+
action = None
89+
90+
total_start_time = timeit.default_timer()
91+
start_time = total_start_time
92+
93+
while True:
94+
action = get_action(env, obs)
95+
if action is None:
96+
break
97+
98+
obs, reward, terminated, truncated, info = env.step(action)
99+
env.render()
100+
101+
steps += 1
102+
total_reward += reward
103+
104+
if not (terminated or truncated):
105+
continue
106+
107+
time_delta = timeit.default_timer() - start_time
108+
109+
print("Final reward:", reward)
110+
print(f"Total reward: {total_reward}, Steps: {steps}, SPS: {steps / time_delta}", total_reward)
111+
pprint.pprint(info)
112+
113+
break
114+
env.close()
115+
116+
117+
if __name__ == "__main__":
118+
main()

0 commit comments

Comments
 (0)