diff --git a/nle/agent/Dockerfile b/nle/agent/Dockerfile new file mode 100644 index 000000000..13e32b50d --- /dev/null +++ b/nle/agent/Dockerfile @@ -0,0 +1,66 @@ +# -*- mode: dockerfile -*- +FROM nvidia/cuda:11.1.1-cudnn8-devel-ubuntu20.04 + +ARG PYTHON_VERSION=3.8 +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -yq \ + bison \ + build-essential \ + cmake \ + curl \ + flex \ + git \ + libbz2-dev \ + ninja-build \ + software-properties-common \ + wget \ + apt-transport-https \ + ca-certificates \ + gnupg + +# Install the latest cmake +RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - +RUN apt-add-repository 'deb https://apt.kitware.com/ubuntu/ focal main' +RUN apt-get update && apt-get --allow-unauthenticated install -yq cmake kitware-archive-keyring + +# Install Conda +WORKDIR /opt/conda_setup +RUN curl -o miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + chmod +x miniconda.sh && \ + ./miniconda.sh -b -p /opt/conda && \ + /opt/conda/bin/conda install -y python=$PYTHON_VERSION && \ + /opt/conda/bin/conda clean -ya +ENV PATH /opt/conda/bin:$PATH + +# Create Env, Install Torch and Keep Env active +RUN conda init bash +RUN conda create -n nle python=3.7 +RUN conda install -n nle pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia +ENV BASH_ENV ~/.bashrc +SHELL ["conda", "run", "-n", "nle", "/bin/bash" ,"-c"] +RUN python -c 'import torch' + +# Install TorchBeast +WORKDIR /opt/ +RUN git clone https://github.com/condnsdmatters/torchbeast.git --branch eric/experimental-port --recursive + +WORKDIR /opt/torchbeast +RUN pip install -r requirements.txt +RUN pip install ./nest +RUN python setup.py install + +# Create Workspace +WORKDIR /opt/workspace +RUN pip install nle \ + hydra-core \ + hydra_colorlog \ + wandb \ + einops + +RUN echo "conda activate nle" >> ~/.bashrc +CMD ["/bin/bash"] + +# Docker commands: +# docker build -t nle . 
+# docker run -v current_dir:/opt/workspace -it nle diff --git a/nle/agent/__init__.py b/nle/agent/__init__.py deleted file mode 100644 index 9020c2df2..000000000 --- a/nle/agent/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/nle/agent/agent.py b/nle/agent/agent.py deleted file mode 100644 index 21f888840..000000000 --- a/nle/agent/agent.py +++ /dev/null @@ -1,956 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This is an example self-contained agent running NLE based on MonoBeast. - -import argparse -import logging -import os -import pprint -import threading -import time -import timeit -import traceback - -# Necessary for multithreading. -os.environ["OMP_NUM_THREADS"] = "1" - -try: - import torch - from torch import multiprocessing as mp - from torch import nn - from torch.nn import functional as F -except ImportError: - logging.exception( - "PyTorch not found. 
Please install the agent dependencies with " - '`pip install "nle[agent]"`' - ) - -import gym # noqa: E402 - -import nle # noqa: F401, E402 -from nle.agent import vtrace # noqa: E402 -from nle import nethack # noqa: E402 - - -# yapf: disable -parser = argparse.ArgumentParser(description="PyTorch Scalable Agent") - -parser.add_argument("--env", type=str, default="NetHackScore-v0", - help="Gym environment.") -parser.add_argument("--mode", default="train", - choices=["train", "test", "test_render"], - help="Training or test mode.") - -# Training settings. -parser.add_argument("--disable_checkpoint", action="store_true", - help="Disable saving checkpoint.") -parser.add_argument("--savedir", default="~/torchbeast/", - help="Root dir where experiment data will be saved.") -parser.add_argument("--num_actors", default=4, type=int, metavar="N", - help="Number of actors (default: 4).") -parser.add_argument("--total_steps", default=100000, type=int, metavar="T", - help="Total environment steps to train for.") -parser.add_argument("--batch_size", default=8, type=int, metavar="B", - help="Learner batch size.") -parser.add_argument("--unroll_length", default=80, type=int, metavar="T", - help="The unroll length (time dimension).") -parser.add_argument("--num_buffers", default=None, type=int, - metavar="N", help="Number of shared-memory buffers.") -parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int, - metavar="N", help="Number learner threads.") -parser.add_argument("--disable_cuda", action="store_true", - help="Disable CUDA.") -parser.add_argument("--use_lstm", action="store_true", - help="Use LSTM in agent model.") - -# Loss settings. 
-parser.add_argument("--entropy_cost", default=0.0006, - type=float, help="Entropy cost/multiplier.") -parser.add_argument("--baseline_cost", default=0.5, - type=float, help="Baseline cost/multiplier.") -parser.add_argument("--discounting", default=0.99, - type=float, help="Discounting factor.") -parser.add_argument("--reward_clipping", default="abs_one", - choices=["abs_one", "none"], - help="Reward clipping.") - -# Optimizer settings. -parser.add_argument("--learning_rate", default=0.00048, - type=float, metavar="LR", help="Learning rate.") -parser.add_argument("--alpha", default=0.99, type=float, - help="RMSProp smoothing constant.") -parser.add_argument("--momentum", default=0, type=float, - help="RMSProp momentum.") -parser.add_argument("--epsilon", default=0.01, type=float, - help="RMSProp epsilon.") -parser.add_argument("--grad_norm_clipping", default=40.0, type=float, - help="Global gradient norm clip.") -# yapf: enable - - -logging.basicConfig( - format=( - "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" - ), - level=logging.INFO, -) - - -def nested_map(f, n): - if isinstance(n, tuple) or isinstance(n, list): - return n.__class__(nested_map(f, sn) for sn in n) - elif isinstance(n, dict): - return {k: nested_map(f, v) for k, v in n.items()} - else: - return f(n) - - -def compute_baseline_loss(advantages): - return 0.5 * torch.sum(advantages ** 2) - - -def compute_entropy_loss(logits): - """Return the entropy loss, i.e., the negative entropy of the policy.""" - policy = F.softmax(logits, dim=-1) - log_policy = F.log_softmax(logits, dim=-1) - return torch.sum(policy * log_policy) - - -def compute_policy_gradient_loss(logits, actions, advantages): - cross_entropy = F.nll_loss( - F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), - target=torch.flatten(actions, 0, 1), - reduction="none", - ) - cross_entropy = cross_entropy.view_as(advantages) - return torch.sum(cross_entropy * advantages.detach()) - - -def create_env(name, *args, 
**kwargs): - return gym.make(name, observation_keys=("glyphs", "blstats"), *args, **kwargs) - - -def act( - flags, - actor_index: int, - free_queue: mp.SimpleQueue, - full_queue: mp.SimpleQueue, - model: torch.nn.Module, - buffers, - initial_agent_state_buffers, -): - try: - logging.info("Actor %i started.", actor_index) - - gym_env = create_env(flags.env, savedir=flags.rundir) - env = ResettingEnvironment(gym_env) - env_output = env.initial() - agent_state = model.initial_state(batch_size=1) - agent_output, unused_state = model(env_output, agent_state) - while True: - index = free_queue.get() - if index is None: - break - - # Write old rollout end. - for key in env_output: - buffers[key][index][0, ...] = env_output[key] - for key in agent_output: - buffers[key][index][0, ...] = agent_output[key] - for i, tensor in enumerate(agent_state): - initial_agent_state_buffers[index][i][...] = tensor - - # Do new rollout. - for t in range(flags.unroll_length): - with torch.no_grad(): - agent_output, agent_state = model(env_output, agent_state) - - env_output = env.step(agent_output["action"]) - - for key in env_output: - buffers[key][index][t + 1, ...] = env_output[key] - for key in agent_output: - buffers[key][index][t + 1, ...] = agent_output[key] - - full_queue.put(index) - - except KeyboardInterrupt: - pass # Return silently. 
- except Exception: - logging.error("Exception in worker process %i", actor_index) - traceback.print_exc() - print() - raise - - -def get_batch( - flags, - free_queue: mp.SimpleQueue, - full_queue: mp.SimpleQueue, - buffers, - initial_agent_state_buffers, - lock=threading.Lock(), -): - with lock: - indices = [full_queue.get() for _ in range(flags.batch_size)] - batch = { - key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers - } - initial_agent_state = ( - torch.cat(ts, dim=1) - for ts in zip(*[initial_agent_state_buffers[m] for m in indices]) - ) - for m in indices: - free_queue.put(m) - batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()} - initial_agent_state = tuple( - t.to(device=flags.device, non_blocking=True) for t in initial_agent_state - ) - return batch, initial_agent_state - - -def learn( - flags, - actor_model, - model, - batch, - initial_agent_state, - optimizer, - scheduler, - lock=threading.Lock(), # noqa: B008 -): - """Performs a learning (optimization) step.""" - with lock: - learner_outputs, unused_state = model(batch, initial_agent_state) - - # Take final value function slice for bootstrapping. - bootstrap_value = learner_outputs["baseline"][-1] - - # Move from obs[t] -> action[t] to action[t] -> obs[t]. 
- batch = {key: tensor[1:] for key, tensor in batch.items()} - learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()} - - rewards = batch["reward"] - if flags.reward_clipping == "abs_one": - clipped_rewards = torch.clamp(rewards, -1, 1) - elif flags.reward_clipping == "none": - clipped_rewards = rewards - - discounts = (~batch["done"]).float() * flags.discounting - - vtrace_returns = vtrace.from_logits( - behavior_policy_logits=batch["policy_logits"], - target_policy_logits=learner_outputs["policy_logits"], - actions=batch["action"], - discounts=discounts, - rewards=clipped_rewards, - values=learner_outputs["baseline"], - bootstrap_value=bootstrap_value, - ) - - pg_loss = compute_policy_gradient_loss( - learner_outputs["policy_logits"], - batch["action"], - vtrace_returns.pg_advantages, - ) - baseline_loss = flags.baseline_cost * compute_baseline_loss( - vtrace_returns.vs - learner_outputs["baseline"] - ) - entropy_loss = flags.entropy_cost * compute_entropy_loss( - learner_outputs["policy_logits"] - ) - - total_loss = pg_loss + baseline_loss + entropy_loss - - episode_returns = batch["episode_return"][batch["done"]] - stats = { - "episode_returns": tuple(episode_returns.cpu().numpy()), - "mean_episode_return": torch.mean(episode_returns).item(), - "total_loss": total_loss.item(), - "pg_loss": pg_loss.item(), - "baseline_loss": baseline_loss.item(), - "entropy_loss": entropy_loss.item(), - } - - optimizer.zero_grad() - total_loss.backward() - nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) - optimizer.step() - scheduler.step() - - actor_model.load_state_dict(model.state_dict()) - return stats - - -def create_buffers(flags, observation_space, num_actions, num_overlapping_steps=1): - size = (flags.unroll_length + num_overlapping_steps,) - - # Get specimens to infer shapes and dtypes. 
- samples = {k: torch.from_numpy(v) for k, v in observation_space.sample().items()} - - specs = { - key: dict(size=size + sample.shape, dtype=sample.dtype) - for key, sample in samples.items() - } - specs.update( - reward=dict(size=size, dtype=torch.float32), - done=dict(size=size, dtype=torch.bool), - episode_return=dict(size=size, dtype=torch.float32), - episode_step=dict(size=size, dtype=torch.int32), - policy_logits=dict(size=size + (num_actions,), dtype=torch.float32), - baseline=dict(size=size, dtype=torch.float32), - last_action=dict(size=size, dtype=torch.int64), - action=dict(size=size, dtype=torch.int64), - ) - buffers = {key: [] for key in specs} - for _ in range(flags.num_buffers): - for key in buffers: - buffers[key].append(torch.empty(**specs[key]).share_memory_()) - return buffers - - -def _format_observations(observation, keys=("glyphs", "blstats")): - observations = {} - for key in keys: - entry = observation[key] - entry = torch.from_numpy(entry) - entry = entry.view((1, 1) + entry.shape) # (...) -> (T,B,...). - observations[key] = entry - return observations - - -class ResettingEnvironment: - """Turns a Gym environment into something that can be step()ed indefinitely.""" - - def __init__(self, gym_env): - self.gym_env = gym_env - self.episode_return = None - self.episode_step = None - - def initial(self): - initial_reward = torch.zeros(1, 1) - # This supports only single-tensor actions ATM. 
- initial_last_action = torch.zeros(1, 1, dtype=torch.int64) - self.episode_return = torch.zeros(1, 1) - self.episode_step = torch.zeros(1, 1, dtype=torch.int32) - initial_done = torch.ones(1, 1, dtype=torch.uint8) - - result = _format_observations(self.gym_env.reset()) - result.update( - reward=initial_reward, - done=initial_done, - episode_return=self.episode_return, - episode_step=self.episode_step, - last_action=initial_last_action, - ) - return result - - def step(self, action): - observation, reward, done, unused_info = self.gym_env.step(action.item()) - self.episode_step += 1 - self.episode_return += reward - episode_step = self.episode_step - episode_return = self.episode_return - if done: - observation = self.gym_env.reset() - self.episode_return = torch.zeros(1, 1) - self.episode_step = torch.zeros(1, 1, dtype=torch.int32) - - result = _format_observations(observation) - - reward = torch.tensor(reward).view(1, 1) - done = torch.tensor(done).view(1, 1) - - result.update( - reward=reward, - done=done, - episode_return=episode_return, - episode_step=episode_step, - last_action=action, - ) - return result - - def close(self): - self.gym_env.close() - - -def train(flags): # pylint: disable=too-many-branches, too-many-statements - flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) - - rundir = os.path.join( - flags.savedir, "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S") - ) - - if not os.path.exists(rundir): - os.makedirs(rundir) - logging.info("Logging results to %s", rundir) - - symlink = os.path.join(flags.savedir, "latest") - try: - if os.path.islink(symlink): - os.remove(symlink) - if not os.path.exists(symlink): - os.symlink(rundir, symlink) - logging.info("Symlinked log directory: %s", symlink) - except OSError: - raise - - logfile = open(os.path.join(rundir, "logs.tsv"), "a", buffering=1) - checkpointpath = os.path.join(rundir, "model.tar") - - flags.rundir = rundir - - if flags.num_buffers is None: # Set sensible default for 
num_buffers. - flags.num_buffers = max(2 * flags.num_actors, flags.batch_size) - if flags.num_actors >= flags.num_buffers: - raise ValueError("num_buffers should be larger than num_actors") - if flags.num_buffers < flags.batch_size: - raise ValueError("num_buffers should be larger than batch_size") - - T = flags.unroll_length - B = flags.batch_size - - flags.device = None - if not flags.disable_cuda and torch.cuda.is_available(): - logging.info("Using CUDA.") - flags.device = torch.device("cuda") - else: - logging.info("Not using CUDA.") - flags.device = torch.device("cpu") - - env = create_env(flags.env, archivefile=None) - observation_space = env.observation_space - action_space = env.action_space - del env # End this before forking. - - model = Net(observation_space, action_space.n, flags.use_lstm) - buffers = create_buffers(flags, observation_space, model.num_actions) - - model.share_memory() - - # Add initial RNN state. - initial_agent_state_buffers = [] - for _ in range(flags.num_buffers): - state = model.initial_state(batch_size=1) - for t in state: - t.share_memory_() - initial_agent_state_buffers.append(state) - - actor_processes = [] - ctx = mp.get_context("fork") - free_queue = ctx.SimpleQueue() - full_queue = ctx.SimpleQueue() - - for i in range(flags.num_actors): - actor = ctx.Process( - target=act, - args=( - flags, - i, - free_queue, - full_queue, - model, - buffers, - initial_agent_state_buffers, - ), - name="Actor-%i" % i, - ) - actor.start() - actor_processes.append(actor) - - learner_model = Net(observation_space, action_space.n, flags.use_lstm).to( - device=flags.device - ) - learner_model.load_state_dict(model.state_dict()) - - optimizer = torch.optim.RMSprop( - learner_model.parameters(), - lr=flags.learning_rate, - momentum=flags.momentum, - eps=flags.epsilon, - alpha=flags.alpha, - ) - - def lr_lambda(epoch): - return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps - - scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
lr_lambda) - - stat_keys = [ - "total_loss", - "mean_episode_return", - "pg_loss", - "baseline_loss", - "entropy_loss", - ] - logfile.write("# Step\t%s\n" % "\t".join(stat_keys)) - - step, stats = 0, {} - - def batch_and_learn(i, lock=threading.Lock()): - """Thread target for the learning process.""" - nonlocal step, stats - while step < flags.total_steps: - batch, agent_state = get_batch( - flags, free_queue, full_queue, buffers, initial_agent_state_buffers - ) - stats = learn( - flags, model, learner_model, batch, agent_state, optimizer, scheduler - ) - with lock: - logfile.write("%i\t" % step) - logfile.write("\t".join(str(stats[k]) for k in stat_keys)) - logfile.write("\n") - step += T * B - - for m in range(flags.num_buffers): - free_queue.put(m) - - threads = [] - for i in range(flags.num_learner_threads): - thread = threading.Thread( - target=batch_and_learn, - name="batch-and-learn-%d" % i, - args=(i,), - daemon=True, # To support KeyboardInterrupt below. - ) - thread.start() - threads.append(thread) - - def checkpoint(): - if flags.disable_checkpoint: - return - logging.info("Saving checkpoint to %s", checkpointpath) - torch.save( - { - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": scheduler.state_dict(), - "flags": vars(flags), - }, - checkpointpath, - ) - - timer = timeit.default_timer - try: - last_checkpoint_time = timer() - while step < flags.total_steps: - start_step = step - start_time = timer() - time.sleep(5) - - if timer() - last_checkpoint_time > 10 * 60: # Save every 10 min. - checkpoint() - last_checkpoint_time = timer() - - sps = (step - start_step) / (timer() - start_time) - if stats.get("episode_returns", None): - mean_return = ( - "Return per episode: %.1f. " % stats["mean_episode_return"] - ) - else: - mean_return = "" - total_loss = stats.get("total_loss", float("inf")) - logging.info( - "Steps %i @ %.1f SPS. Loss %f. 
%sStats:\n%s", - step, - sps, - total_loss, - mean_return, - pprint.pformat(stats), - ) - except KeyboardInterrupt: - logging.warning("Quitting.") - return # Try joining actors then quit. - else: - for thread in threads: - thread.join() - logging.info("Learning finished after %d steps.", step) - finally: - for _ in range(flags.num_actors): - free_queue.put(None) - for actor in actor_processes: - actor.join(timeout=1) - - checkpoint() - logfile.close() - - -def test(flags, num_episodes=10): - flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir)) - checkpointpath = os.path.join(flags.savedir, "latest", "model.tar") - - gym_env = create_env(flags.env, archivefile=None) - env = ResettingEnvironment(gym_env) - model = Net(gym_env.observation_space, gym_env.action_space.n, flags.use_lstm) - model.eval() - checkpoint = torch.load(checkpointpath, map_location="cpu") - model.load_state_dict(checkpoint["model_state_dict"]) - - observation = env.initial() - returns = [] - - agent_state = model.initial_state(batch_size=1) - - while len(returns) < num_episodes: - if flags.mode == "test_render": - env.gym_env.render() - policy_outputs, agent_state = model(observation, agent_state) - observation = env.step(policy_outputs["action"]) - if observation["done"].item(): - returns.append(observation["episode_return"].item()) - logging.info( - "Episode ended after %d steps. 
Return: %.1f", - observation["episode_step"].item(), - observation["episode_return"].item(), - ) - env.close() - logging.info( - "Average returns over %i steps: %.1f", num_episodes, sum(returns) / len(returns) - ) - - -class RandomNet(nn.Module): - def __init__(self, observation_shape, num_actions, use_lstm): - super(RandomNet, self).__init__() - del observation_shape, use_lstm - self.num_actions = num_actions - self.theta = torch.nn.Parameter(torch.zeros(self.num_actions)) - - def forward(self, inputs, core_state): - # print(inputs) - T, B, *_ = inputs["observation"].shape - zeros = self.theta * 0 - # set logits to 0 - policy_logits = zeros[None, :].expand(T * B, -1) - # set baseline to 0 - baseline = policy_logits.sum(dim=1).view(-1, B) - - # sample random action - action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1).view( - T, B - ) - policy_logits = policy_logits.view(T, B, self.num_actions) - return ( - dict(policy_logits=policy_logits, baseline=baseline, action=action), - core_state, - ) - - def initial_state(self, batch_size): - return () - - -def _step_to_range(delta, num_steps): - """Range of `num_steps` integers with distance `delta` centered around zero.""" - return delta * torch.arange(-num_steps // 2, num_steps // 2) - - -class Crop(nn.Module): - """Helper class for NetHackNet below.""" - - def __init__(self, height, width, height_target, width_target): - super(Crop, self).__init__() - self.width = width - self.height = height - self.width_target = width_target - self.height_target = height_target - width_grid = _step_to_range(2 / (self.width - 1), self.width_target)[ - None, : - ].expand(self.height_target, -1) - height_grid = _step_to_range(2 / (self.height - 1), height_target)[ - :, None - ].expand(-1, self.width_target) - - # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880 - self.register_buffer("width_grid", width_grid.clone()) - self.register_buffer("height_grid", height_grid.clone()) - - def forward(self, 
inputs, coordinates): - """Calculates centered crop around given x,y coordinates. - Args: - inputs [B x H x W] - coordinates [B x 2] x,y coordinates - Returns: - [B x H' x W'] inputs cropped and centered around x,y coordinates. - """ - assert inputs.shape[1] == self.height - assert inputs.shape[2] == self.width - - inputs = inputs[:, None, :, :].float() - - x = coordinates[:, 0] - y = coordinates[:, 1] - - x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2) - y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2) - - grid = torch.stack( - [ - self.width_grid[None, :, :] + x_shift[:, None, None], - self.height_grid[None, :, :] + y_shift[:, None, None], - ], - dim=3, - ) - - # TODO: only cast to int if original tensor was int - return ( - torch.round(F.grid_sample(inputs, grid, align_corners=True)) - .squeeze(1) - .long() - ) - - -class NetHackNet(nn.Module): - def __init__( - self, - observation_shape, - num_actions, - use_lstm, - embedding_dim=32, - crop_dim=9, - num_layers=5, - ): - super(NetHackNet, self).__init__() - - self.glyph_shape = observation_shape["glyphs"].shape - self.blstats_size = observation_shape["blstats"].shape[0] - - self.num_actions = num_actions - self.use_lstm = use_lstm - - self.H = self.glyph_shape[0] - self.W = self.glyph_shape[1] - - self.k_dim = embedding_dim - self.h_dim = 512 - - self.crop_dim = crop_dim - - self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim) - - self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim) - - K = embedding_dim # number of input filters - F = 3 # filter dimensions - S = 1 # stride - P = 1 # padding - M = 16 # number of intermediate filters - Y = 8 # number of output filters - L = num_layers # number of convnet layers - - in_channels = [K] + [M] * (L - 1) - out_channels = [M] * (L - 1) + [Y] - - def interleave(xs, ys): - return [val for pair in zip(xs, ys) for val in pair] - - conv_extract = [ - nn.Conv2d( - in_channels=in_channels[i], - out_channels=out_channels[i], - 
kernel_size=(F, F), - stride=S, - padding=P, - ) - for i in range(L) - ] - - self.extract_representation = nn.Sequential( - *interleave(conv_extract, [nn.ELU()] * len(conv_extract)) - ) - - # CNN crop model. - conv_extract_crop = [ - nn.Conv2d( - in_channels=in_channels[i], - out_channels=out_channels[i], - kernel_size=(F, F), - stride=S, - padding=P, - ) - for i in range(L) - ] - - self.extract_crop_representation = nn.Sequential( - *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract)) - ) - - out_dim = self.k_dim - # CNN over full glyph map - out_dim += self.H * self.W * Y - - # CNN crop model. - out_dim += self.crop_dim ** 2 * Y - - self.embed_blstats = nn.Sequential( - nn.Linear(self.blstats_size, self.k_dim), - nn.ReLU(), - nn.Linear(self.k_dim, self.k_dim), - nn.ReLU(), - ) - - self.fc = nn.Sequential( - nn.Linear(out_dim, self.h_dim), - nn.ReLU(), - nn.Linear(self.h_dim, self.h_dim), - nn.ReLU(), - ) - - if self.use_lstm: - self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1) - - self.policy = nn.Linear(self.h_dim, self.num_actions) - self.baseline = nn.Linear(self.h_dim, 1) - - def initial_state(self, batch_size=1): - if not self.use_lstm: - return tuple() - return tuple( - torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size) - for _ in range(2) - ) - - def _select(self, embed, x): - # Work around slow backward pass of nn.Embedding, see - # https://github.com/pytorch/pytorch/issues/24912 - out = embed.weight.index_select(0, x.reshape(-1)) - return out.reshape(x.shape + (-1,)) - - def forward(self, env_outputs, core_state): - # -- [T x B x H x W] - glyphs = env_outputs["glyphs"] - - # -- [T x B x F] - blstats = env_outputs["blstats"] - - T, B, *_ = glyphs.shape - - # -- [B' x H x W] - glyphs = torch.flatten(glyphs, 0, 1) # Merge time and batch. - - # -- [B' x F] - blstats = blstats.view(T * B, -1).float() - - # -- [B x H x W] - glyphs = glyphs.long() - # -- [B x 2] x,y coordinates - coordinates = blstats[:, :2] - # TODO ??? 
- # coordinates[:, 0].add_(-1) - - # -- [B x F] - # FIXME: hack to use compatible blstats to before - # blstats = blstats[:, [0, 1, 21, 10, 11]] - - blstats = blstats.view(T * B, -1).float() - # -- [B x K] - blstats_emb = self.embed_blstats(blstats) - - assert blstats_emb.shape[0] == T * B - - reps = [blstats_emb] - - # -- [B x H' x W'] - crop = self.crop(glyphs, coordinates) - - # print("crop", crop) - # print("at_xy", glyphs[:, coordinates[:, 1].long(), coordinates[:, 0].long()]) - - # -- [B x H' x W' x K] - crop_emb = self._select(self.embed, crop) - - # CNN crop model. - # -- [B x K x W' x H'] - crop_emb = crop_emb.transpose(1, 3) # -- TODO: slow? - # -- [B x W' x H' x K] - crop_rep = self.extract_crop_representation(crop_emb) - - # -- [B x K'] - crop_rep = crop_rep.view(T * B, -1) - assert crop_rep.shape[0] == T * B - - reps.append(crop_rep) - - # -- [B x H x W x K] - glyphs_emb = self._select(self.embed, glyphs) - # glyphs_emb = self.embed(glyphs) - # -- [B x K x W x H] - glyphs_emb = glyphs_emb.transpose(1, 3) # -- TODO: slow? - # -- [B x W x H x K] - glyphs_rep = self.extract_representation(glyphs_emb) - - # -- [B x K'] - glyphs_rep = glyphs_rep.view(T * B, -1) - - assert glyphs_rep.shape[0] == T * B - - # -- [B x K''] - reps.append(glyphs_rep) - - st = torch.cat(reps, dim=1) - - # -- [B x K] - st = self.fc(st) - - if self.use_lstm: - core_input = st.view(T, B, -1) - core_output_list = [] - notdone = (~env_outputs["done"]).float() - for input, nd in zip(core_input.unbind(), notdone.unbind()): - # Reset core state to zero whenever an episode ended. 
- # Make `done` broadcastable with (num_layers, B, hidden_size) - # states: - nd = nd.view(1, -1, 1) - core_state = tuple(nd * s for s in core_state) - output, core_state = self.core(input.unsqueeze(0), core_state) - core_output_list.append(output) - core_output = torch.flatten(torch.cat(core_output_list), 0, 1) - else: - core_output = st - - # -- [B x A] - policy_logits = self.policy(core_output) - # -- [B x A] - baseline = self.baseline(core_output) - - if self.training: - action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1) - else: - # Don't sample when testing. - action = torch.argmax(policy_logits, dim=1) - - policy_logits = policy_logits.view(T, B, self.num_actions) - baseline = baseline.view(T, B) - action = action.view(T, B) - - return ( - dict(policy_logits=policy_logits, baseline=baseline, action=action), - core_state, - ) - - -Net = NetHackNet - - -def main(flags): - if flags.mode == "train": - train(flags) - else: - test(flags) - - -if __name__ == "__main__": - flags = parser.parse_args() - main(flags) diff --git a/nle/agent/config.yaml b/nle/agent/config.yaml new file mode 100644 index 000000000..70f5c7699 --- /dev/null +++ b/nle/agent/config.yaml @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+defaults: +- hydra/job_logging: colorlog +- hydra/hydra_logging: colorlog +# - hydra/launcher: submitit_slurm + +# # To Be Used With hydra submitit_slurm if you have SLURM cluster +# # pip install hydra-core hydra_colorlog +# # can set these on the commandline too, e.g. `hydra.launcher.partition=dev` +# hydra: +# launcher: +# timeout_min: 4300 +# cpus_per_task: 20 +# gpus_per_node: 2 +# tasks_per_node: 1 +# mem_gb: 20 +# nodes: 1 +# partition: dev +# comment: null +# max_num_timeout: 5 # will requeue on timeout or preemption + + +name: null # can use this to have multiple runs with same params, eg name=1,2,3,4,5 + +## WANDB settings +wandb: false # Enable wandb logging. +project: nethack_challenge # The wandb project name. +entity: user1 # The wandb user to log to. +group: group1 # The wandb group for the run. + +# POLYBEAST ENV settings +mock: false # Use mock environment instead of NetHack. +single_ttyrec: true # Record ttyrec only for actor 0. +num_seeds: 0 # If larger than 0, samples fixed number of environment seeds to be used.' +write_profiler_trace: false # Collect and write a profiler trace for chrome://tracing/. +fn_penalty_step: constant # Function to accumulate penalty. +penalty_time: 0.0 # Penalty per time step in the episode. +penalty_step: -0.01 # Penalty per step in the episode. +reward_lose: 0 # Reward for losing (dying before finding the staircase). +reward_win: 100 # Reward for winning (finding the staircase). +state_counter: none # Method for counting state visits. Default none. +character: '@' # Specification of the NetHack character. + ## typical characters we use + # 'mon-hum-neu-mal' + # 'val-dwa-law-fem' + # 'wiz-elf-cha-mal' + # 'tou-hum-neu-fem' + # '@' # random (used in Challenge assessment) + +# RUN settings. +mode: train # Training or test mode. +env: challenge # Name of Gym environment to create. + # # env (task) names: challenge, staircase, pet, + # eat, gold, score, scout, oracle + +# TRAINING settings. 
+num_actors: 256 # Number of actors. +total_steps: 1e9 # Total environment steps to train for. Will be cast to int. +batch_size: 32 # Learner batch size. +unroll_length: 80 # The unroll length (time dimension). +num_learner_threads: 1 # Number learner threads. +num_inference_threads: 1 # Number inference threads. +disable_cuda: false # Disable CUDA. +learner_device: cuda:1 # Set learner device. +actor_device: cuda:0 # Set actor device. + +# OPTIMIZER settings. (RMS Prop) +learning_rate: 0.0002 # Learning rate. +grad_norm_clipping: 40 # Global gradient norm clip. +alpha: 0.99 # RMSProp smoothing constant. +momentum: 0 # RMSProp momentum. +epsilon: 0.000001 # RMSProp epsilon. + +# LOSS settings. +entropy_cost: 0.001 # Entropy cost/multiplier. +baseline_cost: 0.5 # Baseline cost/multiplier. +discounting: 0.999 # Discounting factor. +normalize_reward: true # Normalizes reward by dividing by running stdev from mean. + +# MODEL settings. +model: baseline # Name of model to build (see models/__init__.py). +use_lstm: true # Use LSTM in agent model. +hidden_dim: 256 # Size of hidden representations. +embedding_dim: 64 # Size of glyph embeddings. +layers: 5 # Number of ConvNet Layers for Glyph Model +crop_dim: 9 # Size of crop (c x c) +use_index_select: true # Whether to use index_select instead of embedding lookup (for speed reasons). +restrict_action_space: True # Use a restricted ACTION SPACE (only nethack.USEFUL_ACTIONS) + +msg: + hidden_dim: 64 # Hidden dimension for message encoder. + embedding_dim: 32 # Embedding dimension for characters in message encoder. + +# TEST settings. +load_dir: null # Path to load a model from for testing diff --git a/nle/agent/core/file_writer.py b/nle/agent/core/file_writer.py new file mode 100644 index 000000000..99b06fe4f --- /dev/null +++ b/nle/agent/core/file_writer.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import csv
import datetime
import json
import logging
import os
import time
import weakref


def _save_metadata(path, metadata):
    """Write `metadata` to `path` as pretty-printed JSON.

    Mutates `metadata` in place by stamping a "date_save" entry before
    serializing.
    """
    metadata["date_save"] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    with open(path, "w") as f:
        json.dump(metadata, f, indent=4, sort_keys=True)


def gather_metadata():
    """Collect run metadata: start time, a copy of the process environment,
    git information (only when gitpython is installed and we are inside a git
    repo) and SLURM variables (only when SLURM_JOB_ID is set).

    The "successful" flag starts False and is flipped by
    FileWriter.close(successful=True).
    """
    metadata = dict(
        date_start=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
        env=os.environ.copy(),
        successful=False,
    )

    # Git metadata.
    try:
        import git
    except ImportError:
        logging.warning(
            "Couldn't import gitpython module; install it with `pip install gitpython`."
        )
    else:
        try:
            repo = git.Repo(search_parent_directories=True)
            metadata["git"] = {
                "commit": repo.commit().hexsha,
                "is_dirty": repo.is_dirty(),
                "path": repo.git_dir,
            }
            # active_branch raises on a detached HEAD, so only record the
            # branch name when HEAD is attached.
            if not repo.head.is_detached:
                metadata["git"]["branch"] = repo.active_branch.name
        except git.InvalidGitRepositoryError:
            pass

    if "git" not in metadata:
        logging.warning("Couldn't determine git data.")

    # Slurm metadata.
    if "SLURM_JOB_ID" in os.environ:
        slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")]
        metadata["slurm"] = {}
        for k in slurm_env_keys:
            # Normalize SLURM_/SLURMD_ prefixes away and lowercase the key.
            d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower()
            metadata["slurm"][d_key] = os.environ[k]

    return metadata


class FileWriter:
    """Experiment logger that writes four files under a root directory:

    - out.log:    mirrored log messages (also echoed to stdout)
    - logs.csv:   one row per call to log()
    - fields.csv: one row per change to the set of CSV columns
    - meta.json:  args + environment/git/SLURM metadata

    If the directory already exists, logging resumes: fieldnames are read
    back from fields.csv and the tick counter continues from the last row
    of logs.csv.
    """

    def __init__(self, xp_args=None, rootdir="~/palaas"):
        if rootdir == "~/palaas":
            # make unique id in case someone uses the default rootdir
            xpid = "{proc}_{unixtime}".format(
                proc=os.getpid(), unixtime=int(time.time())
            )
            rootdir = os.path.join(rootdir, xpid)
        self.basepath = os.path.expandvars(os.path.expanduser(rootdir))

        self._tick = 0

        # metadata gathering
        if xp_args is None:
            xp_args = {}
        self.metadata = gather_metadata()
        # we need to copy the args, otherwise when we close the file writer
        # (and rewrite the args) we might have non-serializable objects (or
        # other nasty stuff).
        self.metadata["args"] = copy.deepcopy(xp_args)

        formatter = logging.Formatter("%(message)s")
        self._logger = logging.getLogger("palaas/out")

        # to stdout handler
        shandle = logging.StreamHandler()
        shandle.setFormatter(formatter)
        self._logger.addHandler(shandle)
        self._logger.setLevel(logging.INFO)

        # to file handler
        if not os.path.exists(self.basepath):
            self._logger.info("Creating log directory: %s", self.basepath)
            os.makedirs(self.basepath, exist_ok=True)
        else:
            self._logger.info("Found log directory: %s", self.basepath)

        self.paths = dict(
            msg="{base}/out.log".format(base=self.basepath),
            logs="{base}/logs.csv".format(base=self.basepath),
            fields="{base}/fields.csv".format(base=self.basepath),
            meta="{base}/meta.json".format(base=self.basepath),
        )

        self._logger.info("Saving arguments to %s", self.paths["meta"])
        if os.path.exists(self.paths["meta"]):
            self._logger.warning(
                "Path to meta file already exists. " "Not overriding meta."
            )
        else:
            self.save_metadata()

        self._logger.info("Saving messages to %s", self.paths["msg"])
        if os.path.exists(self.paths["msg"]):
            self._logger.warning(
                "Path to message file already exists. " "New data will be appended."
            )

        fhandle = logging.FileHandler(self.paths["msg"])
        fhandle.setFormatter(formatter)
        self._logger.addHandler(fhandle)

        self._logger.info("Saving logs data to %s", self.paths["logs"])
        self._logger.info("Saving logs' fields to %s", self.paths["fields"])
        self.fieldnames = ["_tick", "_time"]
        if os.path.exists(self.paths["logs"]):
            self._logger.warning(
                "Path to log file already exists. " "New data will be appended."
            )
            # Override default fieldnames.
            with open(self.paths["fields"], "r") as csvfile:
                reader = csv.reader(csvfile)
                lines = list(reader)
                # fields.csv appends a full row of column names each time
                # the schema grows, so the last row is the current schema.
                if len(lines) > 0:
                    self.fieldnames = lines[-1]
            # Override default tick: use the last tick from the logs file plus 1.
            with open(self.paths["logs"], "r") as csvfile:
                reader = csv.reader(csvfile)
                lines = list(reader)
                # Need at least two lines in order to read the last tick:
                # the first is the csv header and the second is the first line
                # of data.
                if len(lines) > 1:
                    self._tick = int(lines[-1][0]) + 1

        self._fieldfile = open(self.paths["fields"], "a")
        self._fieldwriter = csv.writer(self._fieldfile)
        self._fieldfile.flush()
        self._logfile = open(self.paths["logs"], "a")
        # NOTE: the DictWriter keeps a reference to self.fieldnames, so
        # appending to that list in log() also extends the writer's schema.
        self._logwriter = csv.DictWriter(self._logfile, fieldnames=self.fieldnames)

        # Auto-close (and save) on destruction.
        weakref.finalize(self, _save_metadata, self.paths["meta"], self.metadata)

    def log(self, to_log, tick=None, verbose=False):
        """Append one row to logs.csv, stamping "_tick" and "_time".

        Mutates `to_log` in place. New keys extend the schema and are
        recorded as a fresh row in fields.csv. Explicit `tick` values are
        not supported.
        """
        if tick is not None:
            raise NotImplementedError
        else:
            to_log["_tick"] = self._tick
            self._tick += 1
        to_log["_time"] = time.time()

        old_len = len(self.fieldnames)
        for k in to_log:
            if k not in self.fieldnames:
                self.fieldnames.append(k)
        if old_len != len(self.fieldnames):
            # Schema grew: persist the new column list.
            self._fieldwriter.writerow(self.fieldnames)
            self._fieldfile.flush()
            self._logger.info("Updated log fields: %s", self.fieldnames)

        if to_log["_tick"] == 0:
            # First-ever row: emit a commented CSV header line.
            self._logfile.write("# %s\n" % ",".join(self.fieldnames))

        if verbose:
            self._logger.info(
                "LOG | %s",
                ", ".join(["{}: {}".format(k, to_log[k]) for k in sorted(to_log)]),
            )

        self._logwriter.writerow(to_log)
        self._logfile.flush()

    def close(self, successful=True):
        """Record the outcome in meta.json and close the CSV file handles."""
        self.metadata["successful"] = successful
        self.save_metadata()

        for f in [self._logfile, self._fieldfile]:
            f.close()

    def save_metadata(self):
        """Flush the current metadata dict to meta.json."""
        _save_metadata(self.paths["meta"], self.metadata)
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from nle.env import tasks
from nle.env.base import DUNGEON_SHAPE

from .baseline import BaselineNet

from omegaconf import OmegaConf
import torch


# Maps the `flags.env` task name to the corresponding NLE environment class.
ENVS = dict(
    staircase=tasks.NetHackStaircase,
    score=tasks.NetHackScore,
    pet=tasks.NetHackStaircasePet,
    oracle=tasks.NetHackOracle,
    gold=tasks.NetHackGold,
    eat=tasks.NetHackEat,
    scout=tasks.NetHackScout,
    challenge=tasks.NetHackChallenge,
)


def create_model(flags, device):
    """Instantiate the model named by `flags.model` and move it to `device`.

    Args:
        flags: config with at least `model` (currently only "baseline" is
            supported) and `env` (a key of ENVS); the full flags object is
            forwarded to the model constructor.
        device: torch device to place the model on.

    Returns:
        The constructed model, already moved to `device`.

    Raises:
        NotImplementedError: if `flags.model` names an unknown model.
    """
    model_string = flags.model
    if model_string == "baseline":
        model_cls = BaselineNet
    else:
        raise NotImplementedError("model=%s" % model_string)

    # Instantiate a throwaway environment purely to read its action space;
    # no save dir or archive is created for it.
    action_space = ENVS[flags.env](savedir=None, archivefile=None)._actions

    model = model_cls(DUNGEON_SHAPE, action_space, flags, device)
    model.to(device=device)
    return model


def load_model(load_dir, device):
    """Rebuild a model from a training directory and restore its weights.

    Expects `load_dir` to contain `config.yaml` (the training flags) and
    `checkpoint.tar` (a torch checkpoint with a "model_state_dict" entry).

    Returns:
        The reconstructed model with weights loaded, on `device`.
    """
    # os.path.join instead of string concatenation so trailing slashes in
    # load_dir are handled portably.
    flags = OmegaConf.load(os.path.join(load_dir, "config.yaml"))
    flags.checkpoint = os.path.join(load_dir, "checkpoint.tar")
    model = create_model(flags, device)
    checkpoint_states = torch.load(flags.checkpoint, map_location=device)
    model.load_state_dict(checkpoint_states["model_state_dict"])
    return model
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange


from nle import nethack

from util.id_pairs import id_pairs_table
import numpy as np

NUM_GLYPHS = nethack.MAX_GLYPH
NUM_FEATURES = nethack.BLSTATS_SHAPE[0]
# PAD_CHAR is used as the padding index for message characters.
PAD_CHAR = 0
NUM_CHARS = 256


def get_action_space_mask(action_space, reduced_action_space):
    """Return a float tensor of 0s/1s marking which entries of `action_space`
    appear in `reduced_action_space`."""
    mask = np.array([int(a in reduced_action_space) for a in action_space])
    return torch.Tensor(mask)


def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1):
    """Return the dimension after applying a convolution along one axis"""
    return int(1 + (i_dim + 2 * padding - dilation * (k - 1) - 1) / stride)


def select(embedding_layer, x, use_index_select):
    """Use index select instead of default forward to possibly speed up embedding."""
    if use_index_select:
        out = embedding_layer.weight.index_select(0, x.view(-1))
        # handle reshaping x to 1-d and output back to N-d
        return out.view(x.shape + (-1,))
    else:
        return embedding_layer(x)


class NetHackNet(nn.Module):
    """This base class simply provides a skeleton for running with torchbeast."""

    AgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline")

    def __init__(self):
        super(NetHackNet, self).__init__()
        # Running reward statistics, registered as buffers so they are saved
        # in checkpoints but never trained.
        self.register_buffer("reward_sum", torch.zeros(()))
        self.register_buffer("reward_m2", torch.zeros(()))
        # Start near-zero (not exactly zero) to avoid division by zero.
        self.register_buffer("reward_count", torch.zeros(()).fill_(1e-8))

    def forward(self, inputs, core_state):
        raise NotImplementedError

    def initial_state(self, batch_size=1):
        # Stateless by default; subclasses with a recurrent core override this.
        return ()

    @torch.no_grad()
    def update_running_moments(self, reward_batch):
        """Maintains a running mean of reward."""
        new_count = len(reward_batch)
        new_sum = torch.sum(reward_batch)
        new_mean = new_sum / new_count

        curr_mean = self.reward_sum / self.reward_count
        # Merge this batch's statistics into the running ones (parallel-
        # variance style update of the sum of squared deviations).
        new_m2 = torch.sum((reward_batch - new_mean) ** 2) + (
            (self.reward_count * new_count)
            / (self.reward_count + new_count)
            * (new_mean - curr_mean) ** 2
        )

        self.reward_count += new_count
        self.reward_sum += new_sum
        self.reward_m2 += new_m2

    @torch.no_grad()
    def get_running_std(self):
        """Returns standard deviation of the running mean of the reward."""
        return torch.sqrt(self.reward_m2 / self.reward_count)


class BaselineNet(NetHackNet):
    """This model combines the encodings of the glyphs, top line message and
    blstats into a single fixed-size representation, which is then passed to
    an LSTM core before generating a policy and value head for use in an IMPALA
    like architecture.

    This model was based on 'neurips2020release' tag on the NLE repo, itself
    based on Kuttler et al, 2020
    The NetHack Learning Environment
    https://arxiv.org/abs/2006.13760
    """

    def __init__(self, observation_shape, action_space, flags, device):
        super(BaselineNet, self).__init__()

        self.flags = flags

        self.observation_shape = observation_shape
        self.num_actions = len(action_space)

        self.H = observation_shape[0]
        self.W = observation_shape[1]

        self.use_lstm = flags.use_lstm
        self.h_dim = flags.hidden_dim

        # GLYPH + CROP MODEL
        self.glyph_model = GlyphEncoder(flags, self.H, self.W, flags.crop_dim, device)

        # MESSAGING MODEL
        self.msg_model = MessageEncoder(
            flags.msg.hidden_dim, flags.msg.embedding_dim, device
        )

        # BLSTATS MODEL
        self.blstats_model = BLStatsEncoder(NUM_FEATURES, flags.embedding_dim)

        # Concatenated size of the three encoder outputs.
        out_dim = (
            self.blstats_model.hidden_dim
            + self.glyph_model.hidden_dim
            + self.msg_model.hidden_dim
        )

        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        if self.use_lstm:
            self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

        self.policy = nn.Linear(self.h_dim, self.num_actions)
        self.baseline = nn.Linear(self.h_dim, 1)

        if flags.restrict_action_space:
            # Mask logits outside nethack.USEFUL_ACTIONS at forward time;
            # stored as a non-trainable Parameter so it moves with the module
            # across devices and is checkpointed.
            reduced_space = nethack.USEFUL_ACTIONS
            logits_mask = get_action_space_mask(action_space, reduced_space)
            self.policy_logits_mask = nn.parameter.Parameter(
                logits_mask, requires_grad=False
            )

    def initial_state(self, batch_size=1):
        # Hidden and cell states for the LSTM core.
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs, core_state, learning=False):
        T, B, H, W = inputs["glyphs"].shape

        reps = []

        # -- [B' x K] ; B' == (T x B)
        glyphs_rep = self.glyph_model(inputs)
        reps.append(glyphs_rep)

        # -- [B' x K]
        char_rep = self.msg_model(inputs)
        reps.append(char_rep)

        # -- [B' x K]
        features_emb = self.blstats_model(inputs)
        reps.append(features_emb)

        # -- [B' x K]
        st = torch.cat(reps, dim=1)

        # -- [B' x K]
        st = self.fc(st)

        if self.use_lstm:
            core_input = st.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            # Unroll the LSTM one time step at a time so the state can be
            # reset at episode boundaries.
            for input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * t for t in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = st

        # -- [B' x A]
        policy_logits = self.policy(core_output)

        # -- [B' x 1]
        baseline = self.baseline(core_output)

        if self.flags.restrict_action_space:
            # Push masked-out actions to a large negative logit so softmax
            # assigns them (effectively) zero probability.
            policy_logits = policy_logits * self.policy_logits_mask + (
                (1 - self.policy_logits_mask) * -1e10
            )

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, -1)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        output = dict(policy_logits=policy_logits, baseline=baseline, action=action)
        return (output, core_state)


class GlyphEncoder(nn.Module):
    """This glyph encoder first breaks the glyphs (integers up to 6000) to a
    more structured representation based on the qualities of the glyph: chars,
    colors, specials, groups and subgroup ids.
    Eg: invisible hell-hound: char (d), color (red), specials (invisible),
        group (monster) subgroup id (type of monster)
    Eg: lit dungeon floor: char (.), color (white), specials (none),
        group (dungeon) subgroup id (type of dungeon)

    An embedding is provided for each of these, and the embeddings are
    concatenated, before encoding with a number of CNN layers.  This operation
    is repeated with a crop of the structured representations taken around the
    characters position, and the two representations are concatenated
    before returning.
    """

    def __init__(self, flags, rows, cols, crop_dim, device=None):
        super(GlyphEncoder, self).__init__()

        self.crop = Crop(rows, cols, crop_dim, crop_dim, device)
        K = flags.embedding_dim  # number of input filters
        L = flags.layers  # number of convnet layers

        assert (
            K % 8 == 0
        ), "This glyph embedding format needs embedding dim to be multiple of 8"
        # The K input channels are split 2:1:1:1:3 across the five component
        # embeddings below.
        unit = K // 8
        self.chars_embedding = nn.Embedding(256, 2 * unit)
        self.colors_embedding = nn.Embedding(16, unit)
        self.specials_embedding = nn.Embedding(256, unit)

        # Lookup table from glyph id -> (subgroup id, group); non-trainable.
        self.id_pairs_table = nn.parameter.Parameter(
            torch.from_numpy(id_pairs_table()), requires_grad=False
        )
        num_groups = self.id_pairs_table.select(1, 1).max().item() + 1
        num_ids = self.id_pairs_table.select(1, 0).max().item() + 1

        self.groups_embedding = nn.Embedding(num_groups, unit)
        self.ids_embedding = nn.Embedding(num_ids, 3 * unit)

        F = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.output_filters = 8

        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.output_filters]

        h, w, c = rows, cols, crop_dim
        conv_extract, conv_extract_crop = [], []
        for i in range(L):
            conv_extract.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(F, F),
                    stride=S,
                    padding=P,
                )
            )
            conv_extract.append(nn.ELU())

            conv_extract_crop.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(F, F),
                    stride=S,
                    padding=P,
                )
            )
            conv_extract_crop.append(nn.ELU())

            # Keep track of output shapes
            h = conv_outdim(h, F, P, S)
            w = conv_outdim(w, F, P, S)
            c = conv_outdim(c, F, P, S)

        # Flattened size of full-map features plus crop features.
        self.hidden_dim = (h * w + c * c) * self.output_filters
        self.extract_representation = nn.Sequential(*conv_extract)
        self.extract_crop_representation = nn.Sequential(*conv_extract_crop)
        self.select = lambda emb, x: select(emb, x, flags.use_index_select)

    def glyphs_to_ids_groups(self, glyphs):
        """Map a [T x B x H x W] glyph tensor to its (subgroup id, group)
        components via the id_pairs_table lookup."""
        T, B, H, W = glyphs.shape
        ids_groups = self.id_pairs_table.index_select(0, glyphs.view(-1).long())
        ids = ids_groups.select(1, 0).view(T, B, H, W).long()
        groups = ids_groups.select(1, 1).view(T, B, H, W).long()
        return [ids, groups]

    def forward(self, inputs):
        T, B, H, W = inputs["glyphs"].shape
        ids, groups = self.glyphs_to_ids_groups(inputs["glyphs"])

        # Five component embeddings concatenated along the channel dim.
        glyph_tensors = [
            self.select(self.chars_embedding, inputs["chars"].long()),
            self.select(self.colors_embedding, inputs["colors"].long()),
            self.select(self.specials_embedding, inputs["specials"].long()),
            self.select(self.groups_embedding, groups),
            self.select(self.ids_embedding, ids),
        ]

        glyphs_emb = torch.cat(glyph_tensors, dim=-1)
        glyphs_emb = rearrange(glyphs_emb, "T B H W K -> (T B) K H W")

        # First two blstats entries are the agent's x,y coordinates.
        coordinates = inputs["blstats"].view(T * B, -1).float()[:, :2]
        crop_emb = self.crop(glyphs_emb, coordinates)

        glyphs_rep = self.extract_representation(glyphs_emb)
        glyphs_rep = rearrange(glyphs_rep, "B C H W -> B (C H W)")
        assert glyphs_rep.shape[0] == T * B

        crop_rep = self.extract_crop_representation(crop_emb)
        crop_rep = rearrange(crop_rep, "B C H W -> B (C H W)")
        assert crop_rep.shape[0] == T * B

        st = torch.cat([glyphs_rep, crop_rep], dim=1)
        return st


class MessageEncoder(nn.Module):
    """This model encodes the topline message into a fixed size representation.

    It works by using a learnt embedding for each character before passing the
    embeddings through 6 CNN layers.

    Inspired by Zhang et al, 2016
    Character-level Convolutional Networks for Text Classification
    https://arxiv.org/abs/1509.01626
    """

    def __init__(self, hidden_dim, embedding_dim, device=None):
        super(MessageEncoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.msg_edim = embedding_dim

        self.char_lt = nn.Embedding(NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR)
        self.conv1 = nn.Conv1d(self.msg_edim, self.hidden_dim, kernel_size=7)
        self.conv2_6_fc = nn.Sequential(
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # conv2
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # conv3
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
            nn.ReLU(),
            # conv4
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
            nn.ReLU(),
            # conv5
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
            nn.ReLU(),
            # conv6
            nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3, stride=3),
            # fc receives -- [ B x h_dim x 5 ]
            Flatten(),
            nn.Linear(5 * self.hidden_dim, 2 * self.hidden_dim),
            nn.ReLU(),
            nn.Linear(2 * self.hidden_dim, self.hidden_dim),
        )  # final output -- [ B x h_dim ]

    def forward(self, inputs):
        T, B, *_ = inputs["message"].shape
        messages = inputs["message"].long().view(T * B, -1)
        # [ T * B x E x 256 ]
        char_emb = self.char_lt(messages).transpose(1, 2)
        char_rep = self.conv2_6_fc(self.conv1(char_emb))
        return char_rep


class BLStatsEncoder(nn.Module):
    """This model encodes the bottom line stats into a fixed size representation.

    It works by simply using two fully-connected layers with ReLU activations.
    """

    def __init__(self, num_features, hidden_dim):
        super(BLStatsEncoder, self).__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        T, B, *_ = inputs["blstats"].shape

        features = inputs["blstats"]
        # -- [B' x F]
        features = features.view(T * B, -1).float()
        # -- [B x K]
        features_emb = self.embed_features(features)

        assert features_emb.shape[0] == T * B
        return features_emb


class Crop(nn.Module):
    """Extracts a fixed-size crop of a 2D map, centered on given coordinates,
    implemented via F.grid_sample over a precomputed sampling grid."""

    def __init__(self, height, width, height_target, width_target, device=None):
        super(Crop, self).__init__()
        self.width = width
        self.height = height
        self.width_target = width_target
        self.height_target = height_target

        # Base sampling grids in grid_sample's [-1, 1] coordinate system,
        # later shifted by the requested center coordinates in forward().
        width_grid = self._step_to_range(2 / (self.width - 1), self.width_target)
        self.width_grid = width_grid[None, :].expand(self.height_target, -1)

        height_grid = self._step_to_range(2 / (self.height - 1), height_target)
        self.height_grid = height_grid[:, None].expand(-1, self.width_target)

        if device is not None:
            self.width_grid = self.width_grid.to(device)
            self.height_grid = self.height_grid.to(device)

    def _step_to_range(self, step, num_steps):
        """Return a 1D tensor of `num_steps` offsets centered around zero,
        spaced `step` apart."""
        return torch.tensor([step * (i - num_steps // 2) for i in range(num_steps)])

    def forward(self, inputs, coordinates):
        """Calculates centered crop around given x,y coordinates.

        Args:
            inputs [B x H x W] or [B x C x H x W]
            coordinates [B x 2] x,y coordinates

        Returns:
            [B x C x H' x W'] inputs cropped and centered around x,y coordinates.
        """
        if inputs.dim() == 3:
            inputs = inputs.unsqueeze(1).float()

        assert inputs.shape[2] == self.height, "expected %d but found %d" % (
            self.height,
            inputs.shape[2],
        )
        assert inputs.shape[3] == self.width, "expected %d but found %d" % (
            self.width,
            inputs.shape[3],
        )

        x = coordinates[:, 0]
        y = coordinates[:, 1]

        # Convert pixel coordinates into grid_sample's [-1, 1] range.
        x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
        y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)

        grid = torch.stack(
            [
                self.width_grid[None, :, :] + x_shift[:, None, None],
                self.height_grid[None, :, :] + y_shift[:, None, None],
            ],
            dim=3,
        )

        # Round back to discrete values since the input holds embeddings of
        # integer ids sampled on an exact lattice.
        crop = torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1)
        return crop


class Flatten(nn.Module):
    # Flattens all dims after the batch dim: [B, ...] -> [B, prod(...)].
    def forward(self, input):
        return input.view(input.size(0), -1)
def _format_observation(obs):
    """Wrap a numpy observation as a torch tensor with (T=1, B=1) prepended."""
    obs = torch.from_numpy(obs)
    return obs.view((1, 1) + obs.shape)  # (...) -> (T,B,...).


def create_folders(flags):
    # Creates some of the folders that would be created by the filewriter.
    logdir = os.path.join(flags.savedir, "archives")
    if not os.path.exists(logdir):
        logging.info("Creating archive directory: %s" % logdir)
        os.makedirs(logdir, exist_ok=True)
    else:
        logging.info("Found archive directory: %s" % logdir)


def create_env(flags, env_id=0, lock=threading.Lock()):
    """Construct and return the NLE gym environment selected by flags.env.

    NOTE: the mutable default `lock` is deliberate here — it is a single
    module-level lock shared by every call, serializing env construction.
    """
    # commenting out these options for now because they use too much disk space
    # archivefile = "nethack.%i.%%(pid)i.%%(time)s.zip" % env_id
    # if flags.single_ttyrec and env_id != 0:
    #     archivefile = None

    # logdir = os.path.join(flags.savedir, "archives")

    with lock:
        env_class = ENVS[flags.env]
        kwargs = dict(
            savedir=None,
            archivefile=None,
            character=flags.character,
            max_episode_steps=flags.max_num_steps,
            observation_keys=(
                "glyphs",
                "chars",
                "colors",
                "specials",
                "blstats",
                "message",
                "tty_chars",
                "tty_colors",
                "tty_cursor",
                "inv_glyphs",
                "inv_strs",
                "inv_letters",
                "inv_oclasses",
            ),
            penalty_step=flags.penalty_step,
            penalty_time=flags.penalty_time,
            penalty_mode=flags.fn_penalty_step,
        )
        # Only the goal-reaching tasks take explicit win/lose rewards.
        if flags.env in ("staircase", "pet", "oracle"):
            kwargs.update(reward_win=flags.reward_win, reward_lose=flags.reward_lose)
        elif env_id == 0:  # print warning once
            print("Ignoring flags.reward_win and flags.reward_lose")
        if flags.state_counter != "none":
            kwargs.update(state_counter=flags.state_counter)
        env = env_class(**kwargs)
        if flags.seedspath is not None and len(flags.seedspath) > 0:
            raise NotImplementedError("seedspath > 0 not implemented yet.")

        return env


def serve(flags, server_address, env_id):
    """Run a libtorchbeast Server hosting one environment; blocks forever."""
    env = lambda: create_env(flags, env_id)
    server = libtorchbeast.Server(env, server_address=server_address)
    server.run()


def main(flags):
    """Spawn one environment server process per flags.num_servers and wait.

    Each server listens on a unix pipe derived from flags.pipes_basename;
    the parent process just sleeps until interrupted (Ctrl-C).
    """
    if flags.num_seeds > 0:
        raise NotImplementedError("num_seeds > 0 not currently implemented.")

    create_folders(flags)

    if not flags.pipes_basename.startswith("unix:"):
        raise Exception("--pipes_basename has to be of the form unix:/some/path.")

    processes = []
    for i in range(flags.num_servers):
        # daemon=True so servers die with the parent process.
        p = mp.Process(
            target=serve, args=(flags, f"{flags.pipes_basename}.{i}", i), daemon=True
        )
        p.start()
        processes.append(p)

    try:
        # We are only here to listen to the interrupt.
        while True:
            time.sleep(10)
    except KeyboardInterrupt:
        pass
+# + +import collections +import logging +import os +import threading +import time +import timeit +import traceback + +import wandb +import omegaconf +import nest +import torch + +import libtorchbeast + +from core import file_writer +from core import vtrace + +from models import create_model +from models.baseline import NetHackNet + +from torch import nn +from torch.nn import functional as F + + +logging.basicConfig( + format=( + "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" + ), + level=0, +) + + +def compute_baseline_loss(advantages): + return 0.5 * torch.sum(advantages ** 2) + + +def compute_entropy_loss(logits): + policy = F.softmax(logits, dim=-1) + log_policy = F.log_softmax(logits, dim=-1) + entropy_per_timestep = torch.sum(-policy * log_policy, dim=-1) + return -torch.sum(entropy_per_timestep) + + +def compute_policy_gradient_loss(logits, actions, advantages): + cross_entropy = F.nll_loss( + F.log_softmax(torch.flatten(logits, 0, 1), dim=-1), + target=torch.flatten(actions, 0, 1), + reduction="none", + ) + cross_entropy = cross_entropy.view_as(advantages) + policy_gradient_loss_per_timestep = cross_entropy * advantages.detach() + return torch.sum(policy_gradient_loss_per_timestep) + + +def inference( + inference_batcher, model, flags, actor_device, lock=threading.Lock() +): # noqa: B008 + with torch.no_grad(): + for batch in inference_batcher: + batched_env_outputs, agent_state = batch.get_inputs() + observation, reward, done, *_ = batched_env_outputs + # Observation is a dict with keys 'features' and 'glyphs'. + observation["done"] = done + observation, agent_state = nest.map( + lambda t: t.to(actor_device, non_blocking=True), + (observation, agent_state), + ) + with lock: + outputs = model(observation, agent_state) + core_outputs, agent_state = nest.map(lambda t: t.cpu(), outputs) + # Restructuring the output in the way that is expected + # by the functions in actorpool. 
+ outputs = ( + tuple( + ( + core_outputs["action"], + core_outputs["policy_logits"], + core_outputs["baseline"], + ) + ), + agent_state, + ) + batch.set_outputs(outputs) + + +# TODO(heiner): Given that our nest implementation doesn't support +# namedtuples, using them here doesn't seem like a good fit. We +# probably want to nestify the environment server and deal with +# dictionaries? +EnvOutput = collections.namedtuple( + "EnvOutput", "frame rewards done episode_step episode_return" +) +AgentOutput = NetHackNet.AgentOutput +Batch = collections.namedtuple("Batch", "env agent") + + +def learn( + learner_queue, + model, + actor_model, + optimizer, + scheduler, + stats, + flags, + plogger, + learner_device, + lock=threading.Lock(), # noqa: B008 +): + for tensors in learner_queue: + tensors = nest.map(lambda t: t.to(learner_device), tensors) + + batch, initial_agent_state = tensors + env_outputs, actor_outputs = batch + observation, reward, done, *_ = env_outputs + observation["reward"] = reward + observation["done"] = done + + lock.acquire() # Only one thread learning at a time. + + output, _ = model(observation, initial_agent_state, learning=True) + + # Use last baseline value (from the value function) to bootstrap. + learner_outputs = AgentOutput._make( + (output["action"], output["policy_logits"], output["baseline"]) + ) + + # At this point, the environment outputs at time step `t` are the inputs + # that lead to the learner_outputs at time step `t`. After the following + # shifting, the actions in `batch` and `learner_outputs` at time + # step `t` is what leads to the environment outputs at time step `t`. + batch = nest.map(lambda t: t[1:], batch) + learner_outputs = nest.map(lambda t: t[:-1], learner_outputs) + + # Turn into namedtuples again. + env_outputs, actor_outputs = batch + # Note that the env_outputs.frame is now a dict with 'features' and 'glyphs' + # instead of actually being the frame itself. 
This is currently not a problem + # because we never use actor_outputs.frame in the rest of this function. + env_outputs = EnvOutput._make(env_outputs) + actor_outputs = AgentOutput._make(actor_outputs) + learner_outputs = AgentOutput._make(learner_outputs) + + rewards = env_outputs.rewards + if flags.normalize_reward: + model.update_running_moments(rewards) + rewards /= model.get_running_std() + + total_loss = 0 + + # STANDARD EXTRINSIC LOSSES / REWARDS + if flags.entropy_cost > 0: + entropy_loss = flags.entropy_cost * compute_entropy_loss( + learner_outputs.policy_logits + ) + total_loss += entropy_loss + + discounts = (~env_outputs.done).float() * flags.discounting + + # This could be in C++. In TF, this is actually slower on the GPU. + vtrace_returns = vtrace.from_logits( + behavior_policy_logits=actor_outputs.policy_logits, + target_policy_logits=learner_outputs.policy_logits, + actions=actor_outputs.action, + discounts=discounts, + rewards=rewards, + values=learner_outputs.baseline, + bootstrap_value=learner_outputs.baseline[-1], + ) + + # Compute loss as a weighted sum of the baseline loss, the policy + # gradient loss and an entropy regularization term. 
+ pg_loss = compute_policy_gradient_loss( + learner_outputs.policy_logits, + actor_outputs.action, + vtrace_returns.pg_advantages, + ) + baseline_loss = flags.baseline_cost * compute_baseline_loss( + vtrace_returns.vs - learner_outputs.baseline + ) + total_loss += pg_loss + baseline_loss + + # BACKWARD STEP + optimizer.zero_grad() + total_loss.backward() + if flags.grad_norm_clipping > 0: + nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping) + optimizer.step() + scheduler.step() + + actor_model.load_state_dict(model.state_dict()) + + # LOGGING + episode_returns = env_outputs.episode_return[env_outputs.done] + stats["step"] = stats.get("step", 0) + flags.unroll_length * flags.batch_size + stats["mean_episode_return"] = torch.mean(episode_returns).item() + stats["mean_episode_step"] = torch.mean(env_outputs.episode_step.float()).item() + stats["total_loss"] = total_loss.item() + stats["pg_loss"] = pg_loss.item() + stats["baseline_loss"] = baseline_loss.item() + if flags.entropy_cost > 0: + stats["entropy_loss"] = entropy_loss.item() + + stats["learner_queue_size"] = learner_queue.size() + + if not len(episode_returns): + # Hide the mean-of-empty-tuple NaN as it scares people. 
+ stats["mean_episode_return"] = None + + # Only logging if at least one episode was finished + if len(episode_returns): + # TODO: log also SPS + plogger.log(stats) + if flags.wandb: + wandb.log(stats, step=stats["step"]) + + lock.release() + + +def train(flags): + logging.info("Logging results to %s", flags.savedir) + if isinstance(flags, omegaconf.DictConfig): + flag_dict = omegaconf.OmegaConf.to_container(flags) + else: + flag_dict = vars(flags) + plogger = file_writer.FileWriter(xp_args=flag_dict, rootdir=flags.savedir) + + if not flags.disable_cuda and torch.cuda.is_available(): + logging.info("Using CUDA.") + learner_device = torch.device(flags.learner_device) + actor_device = torch.device(flags.actor_device) + else: + logging.info("Not using CUDA.") + learner_device = torch.device("cpu") + actor_device = torch.device("cpu") + + if flags.max_learner_queue_size is None: + flags.max_learner_queue_size = flags.batch_size + + # The queue the learner threads will get their data from. + # Setting `minimum_batch_size == maximum_batch_size` + # makes the batch size static. We could make it dynamic, but that + # requires a loss (and learning rate schedule) that's batch size + # independent. + learner_queue = libtorchbeast.BatchingQueue( + batch_dim=1, + minimum_batch_size=flags.batch_size, + maximum_batch_size=flags.batch_size, + check_inputs=True, + maximum_queue_size=flags.max_learner_queue_size, + ) + + # The "batcher", a queue for the inference call. Will yield + # "batch" objects with `get_inputs` and `set_outputs` methods. + # The batch size of the tensors will be dynamic. 
+ inference_batcher = libtorchbeast.DynamicBatcher( + batch_dim=1, + minimum_batch_size=1, + maximum_batch_size=512, + timeout_ms=100, + check_outputs=True, + ) + + addresses = [] + connections_per_server = 1 + pipe_id = 0 + while len(addresses) < flags.num_actors: + for _ in range(connections_per_server): + addresses.append(f"{flags.pipes_basename}.{pipe_id}") + if len(addresses) == flags.num_actors: + break + pipe_id += 1 + + logging.info("Using model %s", flags.model) + + model = create_model(flags, learner_device) + + plogger.metadata["model_numel"] = sum( + p.numel() for p in model.parameters() if p.requires_grad + ) + + logging.info("Number of model parameters: %i", plogger.metadata["model_numel"]) + + actor_model = create_model(flags, actor_device) + + # The ActorPool that will run `flags.num_actors` many loops. + actors = libtorchbeast.ActorPool( + unroll_length=flags.unroll_length, + learner_queue=learner_queue, + inference_batcher=inference_batcher, + env_server_addresses=addresses, + initial_agent_state=model.initial_state(), + ) + + def run(): + try: + actors.run() + except Exception as e: + logging.error("Exception in actorpool thread!") + traceback.print_exc() + print() + raise e + + actorpool_thread = threading.Thread(target=run, name="actorpool-thread") + + optimizer = torch.optim.RMSprop( + model.parameters(), + lr=flags.learning_rate, + momentum=flags.momentum, + eps=flags.epsilon, + alpha=flags.alpha, + ) + + def lr_lambda(epoch): + return ( + 1 + - min(epoch * flags.unroll_length * flags.batch_size, flags.total_steps) + / flags.total_steps + ) + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) + + stats = {} + + if flags.checkpoint and os.path.exists(flags.checkpoint): + logging.info("Loading checkpoint: %s" % flags.checkpoint) + checkpoint_states = torch.load( + flags.checkpoint, map_location=flags.learner_device + ) + model.load_state_dict(checkpoint_states["model_state_dict"]) + 
optimizer.load_state_dict(checkpoint_states["optimizer_state_dict"]) + scheduler.load_state_dict(checkpoint_states["scheduler_state_dict"]) + stats = checkpoint_states["stats"] + logging.info(f"Resuming preempted job, current stats:\n{stats}") + + # Initialize actor model like learner model. + actor_model.load_state_dict(model.state_dict()) + + learner_threads = [ + threading.Thread( + target=learn, + name="learner-thread-%i" % i, + args=( + learner_queue, + model, + actor_model, + optimizer, + scheduler, + stats, + flags, + plogger, + learner_device, + ), + ) + for i in range(flags.num_learner_threads) + ] + inference_threads = [ + threading.Thread( + target=inference, + name="inference-thread-%i" % i, + args=(inference_batcher, actor_model, flags, actor_device), + ) + for i in range(flags.num_inference_threads) + ] + + actorpool_thread.start() + for t in learner_threads + inference_threads: + t.start() + + def checkpoint(checkpoint_path=None): + if flags.checkpoint: + if checkpoint_path is None: + checkpoint_path = flags.checkpoint + logging.info("Saving checkpoint to %s", checkpoint_path) + torch.save( + { + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "stats": stats, + "flags": vars(flags), + }, + checkpoint_path, + ) + + def format_value(x): + return f"{x:1.5}" if isinstance(x, float) else str(x) + + try: + train_start_time = timeit.default_timer() + train_time_offset = stats.get("train_seconds", 0) # used for resuming training + last_checkpoint_time = timeit.default_timer() + + dev_checkpoint_intervals = [0, 0.25, 0.5, 0.75] + + loop_start_time = timeit.default_timer() + loop_start_step = stats.get("step", 0) + while True: + if loop_start_step >= flags.total_steps: + break + time.sleep(5) + loop_end_time = timeit.default_timer() + loop_end_step = stats.get("step", 0) + + stats["train_seconds"] = round( + loop_end_time - train_start_time + train_time_offset, 1 + ) 
+ + if loop_end_time - last_checkpoint_time > 10 * 60: + # Save every 10 min. + checkpoint() + last_checkpoint_time = loop_end_time + + if len(dev_checkpoint_intervals) > 0: + step_percentage = loop_end_step / flags.total_steps + i = dev_checkpoint_intervals[0] + if step_percentage > i: + checkpoint(flags.checkpoint[:-4] + "_" + str(i) + ".tar") + dev_checkpoint_intervals = dev_checkpoint_intervals[1:] + + logging.info( + "Step %i @ %.1f SPS. Inference batcher size: %i." + " Learner queue size: %i." + " Other stats: (%s)", + loop_end_step, + (loop_end_step - loop_start_step) / (loop_end_time - loop_start_time), + inference_batcher.size(), + learner_queue.size(), + ", ".join( + f"{key} = {format_value(value)}" for key, value in stats.items() + ), + ) + loop_start_time = loop_end_time + loop_start_step = loop_end_step + except KeyboardInterrupt: + pass # Close properly. + else: + logging.info("Learning finished after %i steps.", stats["step"]) + + checkpoint() + + # Done with learning. Let's stop all the ongoing work. 
+ inference_batcher.close() + learner_queue.close() + + actorpool_thread.join() + + for t in learner_threads + inference_threads: + t.join() + + +def test(flags): + test_checkpoint = os.path.join(flags.savedir, "test_checkpoint.tar") + checkpoint = os.path.join(flags.load_dir, "checkpoint.tar") + if not os.path.exists(os.path.dirname(test_checkpoint)): + os.makedirs(os.path.dirname(test_checkpoint)) + + logging.info("Creating test copy of checkpoint '%s'", checkpoint) + + checkpoint = torch.load(checkpoint) + for d in checkpoint["optimizer_state_dict"]["param_groups"]: + d["lr"] = 0.0 + d["initial_lr"] = 0.0 + + checkpoint["scheduler_state_dict"]["last_epoch"] = 0 + checkpoint["scheduler_state_dict"]["_step_count"] = 0 + checkpoint["scheduler_state_dict"]["base_lrs"] = [0.0] + checkpoint["stats"]["step"] = 0 + checkpoint["stats"]["_tick"] = 0 + + flags.checkpoint = test_checkpoint + flags.learning_rate = 0.0 + + logging.info("Saving test checkpoint to %s", test_checkpoint) + torch.save(checkpoint, test_checkpoint) + + train(flags) + + +def main(flags): + if flags.wandb: + wandb.init( + project=flags.project, + config=vars(flags), + group=flags.group, + entity=flags.entity, + ) + if flags.mode == "train": + if flags.write_profiler_trace: + logging.info("Running with profiler.") + with torch.autograd.profiler.profile() as prof: + train(flags) + filename = "chrome-%s.trace" % time.strftime("%Y%m%d-%H%M%S") + logging.info("Writing profiler trace to '%s.gz'", filename) + prof.export_chrome_trace(filename) + os.system("gzip %s" % filename) + else: + train(flags) + elif flags.mode.startswith("test"): + test(flags) diff --git a/nle/agent/polyhydra.py b/nle/agent/polyhydra.py new file mode 100644 index 000000000..1554574a4 --- /dev/null +++ b/nle/agent/polyhydra.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Installation for hydra: +pip install hydra-core hydra_colorlog --upgrade + +Runs like polybeast but use = to set flags: +python -m polyhydra.py learning_rate=0.001 rnd.twoheaded=true + +Run sweep with another -m after the module: +python -m polyhydra.py -m learning_rate=0.01,0.001,0.0001,0.00001 momentum=0,0.5 + +Baseline should run with: +python polyhydra.py +""" + +from pathlib import Path +import logging +import os +import multiprocessing as mp + +import hydra +import numpy as np +from omegaconf import OmegaConf, DictConfig + +import torch + +import polybeast_env +import polybeast_learner + +if torch.__version__.startswith("1.5") or torch.__version__.startswith("1.6"): + # pytorch 1.5.* needs this for some reason on the cluster + os.environ["MKL_SERVICE_FORCE_INTEL"] = "1" +os.environ["OMP_NUM_THREADS"] = "1" + +logging.basicConfig( + format=( + "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s" + ), + level=0, +) + + +def pipes_basename(): + logdir = Path(os.getcwd()) + name = ".".join([logdir.parents[1].name, logdir.parents[0].name, logdir.name]) + return "unix:/tmp/poly.%s" % name + + +def get_common_flags(flags): + flags = OmegaConf.to_container(flags) + flags["pipes_basename"] = pipes_basename() + flags["savedir"] = os.getcwd() + return OmegaConf.create(flags) + + +def get_learner_flags(flags): + lrn_flags = OmegaConf.to_container(flags) + lrn_flags["checkpoint"] = os.path.join(flags["savedir"], "checkpoint.tar") + lrn_flags["entropy_cost"] = float(lrn_flags["entropy_cost"]) + return OmegaConf.create(lrn_flags) + 
+ +def run_learner(flags: DictConfig): + polybeast_learner.main(flags) + + +def get_environment_flags(flags): + env_flags = OmegaConf.to_container(flags) + env_flags["num_servers"] = flags.num_actors + max_num_steps = 1e6 + if flags.env in ("staircase", "pet"): + max_num_steps = 1000 + env_flags["max_num_steps"] = int(max_num_steps) + env_flags["seedspath"] = "" + return OmegaConf.create(env_flags) + + +def run_env(flags): + np.random.seed() # Get new random seed in forked process. + polybeast_env.main(flags) + + +def symlink_latest(savedir, symlink): + try: + if os.path.islink(symlink): + os.remove(symlink) + if not os.path.exists(symlink): + os.symlink(savedir, symlink) + logging.info("Symlinked log directory: %s" % symlink) + except OSError: + # os.remove() or os.symlink() raced. Don't do anything. + pass + + +@hydra.main(config_name="config") +def main(flags: DictConfig): + if os.path.exists("config.yaml"): + # this ignores the local config.yaml and replaces it completely with saved one + logging.info("loading existing configuration, we're continuing a previous run") + new_flags = OmegaConf.load("config.yaml") + cli_conf = OmegaConf.from_cli() + # however, you can override parameters from the cli still + # this is useful e.g. 
if you did total_steps=N before and want to increase it + flags = OmegaConf.merge(new_flags, cli_conf) + if flags.load_dir and os.path.exists(os.path.join(flags.load_dir, "config.yaml")): + new_flags = OmegaConf.load(os.path.join(flags.load_dir, "config.yaml")) + cli_conf = OmegaConf.from_cli() + flags = OmegaConf.merge(new_flags, cli_conf) + + logging.info(flags.pretty(resolve=True)) + OmegaConf.save(flags, "config.yaml") + + flags = get_common_flags(flags) + + # set flags for polybeast_env + env_flags = get_environment_flags(flags) + env_processes = [] + for _ in range(1): + p = mp.Process(target=run_env, args=(env_flags,)) + p.start() + env_processes.append(p) + + symlink_latest( + flags.savedir, os.path.join(hydra.utils.get_original_cwd(), "latest") + ) + + lrn_flags = get_learner_flags(flags) + run_learner(lrn_flags) + + for p in env_processes: + p.kill() + p.join() + + +if __name__ == "__main__": + main() diff --git a/nle/agent/requirements.txt b/nle/agent/requirements.txt new file mode 100644 index 000000000..468f99c07 --- /dev/null +++ b/nle/agent/requirements.txt @@ -0,0 +1,7 @@ +nle +hydra-core +hydra_colorlog +wandb +einops +torch +numpy diff --git a/nle/agent/util/__init__.py b/nle/agent/util/__init__.py new file mode 100644 index 000000000..8daf2005d --- /dev/null +++ b/nle/agent/util/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/nle/agent/util/id_pairs.py b/nle/agent/util/id_pairs.py new file mode 100644 index 000000000..30352401f --- /dev/null +++ b/nle/agent/util/id_pairs.py @@ -0,0 +1,142 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum + +import numpy as np + +from nle.nethack import * # noqa: F403 + +# flake8: noqa: F405 + +# TODO: import this from NLE again +NUM_OBJECTS = 453 +MAXEXPCHARS = 9 + + +class GlyphGroup(enum.IntEnum): + # See display.h in NetHack. 
+ MON = 0 + PET = 1 + INVIS = 2 + DETECT = 3 + BODY = 4 + RIDDEN = 5 + OBJ = 6 + CMAP = 7 + EXPLODE = 8 + ZAP = 9 + SWALLOW = 10 + WARNING = 11 + STATUE = 12 + + +def id_pairs_table(): + """Returns a lookup table for glyph -> NLE id pairs.""" + table = np.zeros([MAX_GLYPH, 2], dtype=np.int16) + + num_nle_ids = 0 + + for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF): + table[glyph] = (glyph, GlyphGroup.MON) + num_nle_ids += 1 + + for glyph in range(GLYPH_PET_OFF, GLYPH_INVIS_OFF): + table[glyph] = (glyph - GLYPH_PET_OFF, GlyphGroup.PET) + + for glyph in range(GLYPH_INVIS_OFF, GLYPH_DETECT_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.INVIS) + num_nle_ids += 1 + + for glyph in range(GLYPH_DETECT_OFF, GLYPH_BODY_OFF): + table[glyph] = (glyph - GLYPH_DETECT_OFF, GlyphGroup.DETECT) + + for glyph in range(GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF): + table[glyph] = (glyph - GLYPH_BODY_OFF, GlyphGroup.BODY) + + for glyph in range(GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF): + table[glyph] = (glyph - GLYPH_RIDDEN_OFF, GlyphGroup.RIDDEN) + + for glyph in range(GLYPH_OBJ_OFF, GLYPH_CMAP_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.OBJ) + num_nle_ids += 1 + + for glyph in range(GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.CMAP) + num_nle_ids += 1 + + for glyph in range(GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF): + id_ = num_nle_ids + (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + table[glyph] = (id_, GlyphGroup.EXPLODE) + + num_nle_ids += EXPL_MAX + + for glyph in range(GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF): + id_ = num_nle_ids + (glyph - GLYPH_ZAP_OFF) // 4 + table[glyph] = (id_, GlyphGroup.ZAP) + + num_nle_ids += NUM_ZAP + + for glyph in range(GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.SWALLOW) + num_nle_ids += 1 + + for glyph in range(GLYPH_WARNING_OFF, GLYPH_STATUE_OFF): + table[glyph] = (num_nle_ids, GlyphGroup.WARNING) + num_nle_ids += 1 + + for glyph in range(GLYPH_STATUE_OFF, MAX_GLYPH): + table[glyph] = (glyph - 
GLYPH_STATUE_OFF, GlyphGroup.STATUE) + + return table + + +def id_pairs_func(glyph): + result = glyph_to_mon(glyph) + if result != NO_GLYPH: + return result + if glyph_is_invisible(glyph): + return NUMMONS + if glyph_is_body(glyph): + return glyph - GLYPH_BODY_OFF + + offset = NUMMONS + 1 + + # CORPSE handled by glyph_is_body; STATUE handled by glyph_to_mon. + result = glyph_to_obj(glyph) + if result != NO_GLYPH: + return result + offset + offset += NUM_OBJECTS + + # I don't understand glyph_to_cmap and/or the GLYPH_EXPLODE_OFF definition + # with MAXPCHARS - MAXEXPCHARS. + if GLYPH_CMAP_OFF <= glyph < GLYPH_EXPLODE_OFF: + return glyph - GLYPH_CMAP_OFF + offset + offset += MAXPCHARS - MAXEXPCHARS + + if GLYPH_EXPLODE_OFF <= glyph < GLYPH_ZAP_OFF: + return (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS + offset + offset += EXPL_MAX + + if GLYPH_ZAP_OFF <= glyph < GLYPH_SWALLOW_OFF: + return ((glyph - GLYPH_ZAP_OFF) >> 2) + offset + offset += NUM_ZAP + + if GLYPH_SWALLOW_OFF <= glyph < GLYPH_WARNING_OFF: + return offset + offset += 1 + + result = glyph_to_warning(glyph) + if result != NO_GLYPH: + return result + offset