|
| 1 | +import os |
| 2 | +import random |
| 3 | +import readline |
| 4 | +import timeit |
| 5 | +from datetime import datetime |
| 6 | +from functools import partial |
| 7 | +from pathlib import Path |
| 8 | +from pprint import pprint |
| 9 | + |
| 10 | +import hydra |
| 11 | +import numpy as np |
| 12 | +from hydra.utils import get_original_cwd |
| 13 | +from omegaconf import DictConfig |
| 14 | + |
| 15 | +from balrog.agents import AgentFactory |
| 16 | +from balrog.environments import make_env |
| 17 | +from balrog.evaluator import EvaluatorManager |
| 18 | +from balrog.utils import get_unique_seed, setup_environment |
| 19 | + |
| 20 | + |
def completer(text, state, commands=()):
    """Return the ``state``-th command that starts with ``text``.

    Shaped as a ``readline`` completer callback: readline calls it with
    ``state`` = 0, 1, 2, ... and stops at the first ``None``.

    Args:
        text: Prefix typed so far.
        state: Zero-based index of the match to return.
        commands: Iterable of candidate command strings. Defaults to an
            immutable empty tuple rather than a shared mutable list.

    Returns:
        The matching command, or ``None`` when ``state`` is out of range.
    """
    options = [cmd for cmd in commands if cmd.startswith(text)]
    # Guard the lower bound too: a negative state would otherwise index
    # from the end of the list instead of signalling "no more matches".
    return options[state] if 0 <= state < len(options) else None
| 24 | + |
| 25 | + |
def setup_autocomplete(completer_fn):
    """Bind TAB to completion, print usage hints and install ``completer_fn``.

    Args:
        completer_fn: Callable ``(text, state)`` handed to
            ``readline.set_completer``.
    """
    # The readline module may be backed by libedit (e.g. the system Python on
    # macOS), which does not understand the GNU "tab: complete" directive and
    # needs its own binding syntax instead.
    if "libedit" in (readline.__doc__ or ""):
        readline.parse_and_bind("bind ^I rl_complete")
    else:
        readline.parse_and_bind("tab: complete")
    print("Type commands and use TAB to autocomplete.")
    print("To see strategies use command: `help`")
    readline.set_completer(completer_fn)
| 31 | + |
| 32 | + |
def get_action(env, obs):
    """Prompt the user until they enter an action from the env's action list.

    The list is read from the environment's ``language_action_space`` wrapper
    attribute and is also used to drive TAB completion. Typing ``help``
    prints the list and prompts again.

    Args:
        env: Environment exposing ``get_wrapper_attr("language_action_space")``.
        obs: Current observation. Unused here; kept for the caller's signature.

    Returns:
        The chosen action string, or ``None`` when input is closed (EOF /
        Ctrl-D) so the caller can end the session.
    """
    language_action_space = env.get_wrapper_attr("language_action_space")
    setup_autocomplete(partial(completer, commands=language_action_space))

    while True:
        try:
            command = input("> ")
        except EOFError:
            # Input stream closed: report "no action" instead of crashing.
            return None

        if command == "help":
            print(language_action_space)
            continue

        # Explicit membership test rather than `assert`: assertions are
        # stripped under `python -O`, which would accept any input.
        if command in language_action_space:
            return command

        print(f"Selected action '{command}' is not in action list. Please try again.")
| 52 | + |
| 53 | + |
@hydra.main(config_path="balrog/config", config_name="config", version_base="1.1")
def main(config: DictConfig):
    """Play one randomly chosen environment/task interactively from the terminal.

    Creates the run's output directory, picks a random environment name and
    task from ``config``, then loops: read an action from the user, step the
    environment, render. When the episode terminates or is truncated, a
    summary (final reward, total reward, steps, steps per second) and the
    last ``info`` are printed.

    Args:
        config: Hydra configuration. Reads ``eval.resume_from``,
            ``eval.output_dir``, ``agent.type``, ``client.model_id``,
            ``envs.names``, ``envs.env_kwargs.seed`` and ``tasks``.
    """
    original_cwd = get_original_cwd()
    setup_environment(original_cwd=original_cwd)

    # Determine output directory
    if config.eval.resume_from is not None:
        output_dir = config.eval.resume_from
    else:
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        run_name = f"{timestamp}_{config.agent.type}_{config.client.model_id.replace('/', '_')}"
        output_dir = os.path.join(config.eval.output_dir, run_name)

    # Create the directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # `envs.names` is a "-"-separated list of environment names.
    env_name = random.choice(config.envs.names.split("-"))
    task = random.choice(config.tasks[f"{env_name}_tasks"])
    print(f"Selected environment: {env_name}, task: {task}")

    env = make_env(env_name, task, config, render_mode="human")
    try:
        seed = config.envs.env_kwargs.seed
        if seed is None:
            seed = get_unique_seed(process_num=None, episode_idx=0)
        random.seed(seed)
        np.random.seed(seed)
        obs, info = env.reset(seed=seed)
        env.render()

        steps = 0
        total_reward = 0.0
        start_time = timeit.default_timer()

        while True:
            action = get_action(env, obs)
            if action is None:
                break

            obs, reward, terminated, truncated, info = env.step(action)
            env.render()

            steps += 1
            total_reward += reward

            if not (terminated or truncated):
                continue

            time_delta = timeit.default_timer() - start_time

            print("Final reward:", reward)
            print(f"Total reward: {total_reward}, Steps: {steps}, SPS: {steps / time_delta}")
            # `pprint` is imported as the function (from pprint import pprint),
            # so it is called directly.
            pprint(info)

            break
    finally:
        # Always release the environment, even if a step or render raises.
        env.close()
| 115 | + |
| 116 | + |
# Script entry point: start the interactive session only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
0 commit comments