diff --git a/README.md b/README.md
index f2d3f0a9a..b1f8e8884 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,8 @@ Following useful techniques have been also implemented in PFRL:
   - examples: [[Rainbow]](examples/atari/reproduction/rainbow) [[DQN/DoubleDQN/PAL]](examples/atari/train_dqn_ale.py)
 - [Prioritized Experience Replay](https://arxiv.org/abs/1511.05952)
   - examples: [[Rainbow]](examples/atari/reproduction/rainbow) [[DQN/DoubleDQN/PAL]](examples/atari/train_dqn_ale.py)
+- [Hindsight Experience Replay](https://arxiv.org/abs/1707.01495)
+  - examples: [[Bit-flip DQN]](examples/her/train_dqn_bit_flip.py) [[DDPG on Fetch Envs]](examples/her/train_ddpg_her_fetch.py)
 - [Dueling Network](https://arxiv.org/abs/1511.06581)
   - examples: [[Rainbow]](examples/atari/reproduction/rainbow) [[DQN/DoubleDQN/PAL]](examples/atari/train_dqn_ale.py)
 - [Normalized Advantage Function](https://arxiv.org/abs/1603.00748)
diff --git a/examples/her/README.md b/examples/her/README.md
new file mode 100644
index 000000000..28113c9ec
--- /dev/null
+++ b/examples/her/README.md
@@ -0,0 +1,23 @@
+# Hindsight Experience Replay
+These two examples train agents using [Hindsight Experience Replay (HER)](https://arxiv.org/abs/1707.01495). The first example, `train_dqn_bit_flip.py`, trains a DoubleDQN agent on the BitFlip environment described in the HER paper. The second example, `train_ddpg_her_fetch.py`, trains DDPG agents on the robotic Fetch environments, also described in the paper.
+
+## How to Run
+
+To run the BitFlip example:
+```
+python train_dqn_bit_flip.py --num-bits <number-of-bits>
+```
+
+To run DDPG with HER on the Fetch tasks:
+```
+python train_ddpg_her_fetch.py --env <env-id>
+```
+
+Options:
+- `--gpu`: Set to -1 if you have no GPU.
+
+## Results and Reproducibility
+The BitFlip environment is implemented as described in the paper. The DQN setup used for the BitFlip environment is not taken from the paper (to our knowledge there is no publicly released reference implementation).
+
+For the Fetch environments, we added an action penalty, return clipping, and observation normalization to DDPG, following the [OpenAI baselines implementation](https://github.com/openai/baselines/tree/master/baselines/her).
+
diff --git a/examples/her/train_ddpg_her_fetch.py b/examples/her/train_ddpg_her_fetch.py
new file mode 100644
index 000000000..d0a825aa1
--- /dev/null
+++ b/examples/her/train_ddpg_her_fetch.py
@@ -0,0 +1,325 @@
+import argparse
+
+import gym
+import gym.spaces
+import numpy as np
+import torch
+import torch.nn as nn
+
+import pfrl
+from pfrl import experiments, replay_buffers, utils
+from pfrl.nn import BoundByTanh, ConcatObsAndAction
+from pfrl.policies import DeterministicHead
+
+
+class ComputeSuccessRate(gym.Wrapper):
+    """Environment wrapper that computes success rate.
+ + Args: + env: Env to wrap + + Attributes: + success_record: list of successes + """ + + def __init__(self, env): + super().__init__(env) + self.success_record = [] + + def reset(self): + self.success_record.append(None) + return self.env.reset() + + def step(self, action): + obs, r, done, info = self.env.step(action) + assert "is_success" in info + self.success_record[-1] = info["is_success"] + return obs, r, done, info + + def get_statistics(self): + # Ignore episodes with zero step + valid_record = [x for x in self.success_record if x is not None] + success_rate = ( + valid_record.count(True) / len(valid_record) if valid_record else np.nan + ) + return [("success_rate", success_rate)] + + def clear_statistics(self): + self.success_record = [] + + +class ClipObservation(gym.ObservationWrapper): + """Clip observations to a given range. + + Args: + env: Env to wrap. + low: Lower limit. + high: Upper limit. + + Attributes: + original_observation: Observation before casting. + """ + + def __init__(self, env, low, high): + super().__init__(env) + self.low = low + self.high = high + + def observation(self, observation): + self.original_observation = observation + return np.clip(observation, self.low, self.high) + + +class EpsilonGreedyWithGaussianNoise(pfrl.explorer.Explorer): + """Epsilon-Greedy with Gaussian noise. + + This type of explorer was used in + https://github.com/openai/baselines/tree/master/baselines/her + """ + + def __init__(self, epsilon, random_action_func, noise_scale, low=None, high=None): + self.epsilon = epsilon + self.random_action_func = random_action_func + self.noise_scale = noise_scale + self.low = low + self.high = high + + def select_action(self, t, greedy_action_func, action_value=None): + if np.random.rand() < self.epsilon: + a = self.random_action_func() + else: + a = greedy_action_func() + noise = np.random.normal(scale=self.noise_scale, size=a.shape).astype( + np.float32 + ) + a = a + noise + if self.low is not None or self.high is not None: + return np.clip(a, self.low, self.high) + else: + return a + + def __repr__(self): + return ( + "EpsilonGreedyWithGaussianNoise(epsilon={}, noise_scale={}, low={}," + " high={})".format(self.epsilon, self.noise_scale, self.low, self.high) + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--outdir", + type=str, + default="results", + help=( + "Directory path to save output files." + " If it does not exist, it will be created." + ), + ) + parser.add_argument( + "--env", + type=str, + default="FetchReach-v1", + help="OpenAI Gym MuJoCo env to perform algorithm on.", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)") + parser.add_argument( + "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU." + ) + parser.add_argument("--demo", action="store_true", default=False) + parser.add_argument("--load", type=str, default=None) + parser.add_argument( + "--log-level", + type=int, + default=20, + help="Logging level. 
10:DEBUG, 20:INFO etc.", + ) + parser.add_argument( + "--steps", + type=int, + default=5 * 10 ** 3, + help="Total number of timesteps to train the agent.", + ) + parser.add_argument( + "--replay-start-size", + type=int, + default=5 * 10 ** 2, + help="Minimum replay buffer size before performing gradient updates.", + ) + parser.add_argument( + "--replay-strategy", + default="future", + choices=["future", "final"], + help="The replay strategy to use", + ) + parser.add_argument( + "--no-hindsight", + action="store_true", + default=False, + help="Do not use Hindsight Replay", + ) + parser.add_argument("--eval-n-episodes", type=int, default=10) + parser.add_argument("--eval-interval", type=int, default=500) + parser.add_argument( + "--render", action="store_true", help="Render env states in a GUI window." + ) + args = parser.parse_args() + + import logging + + logging.basicConfig(level=args.log_level) + + # Set a random seed used in PFRL. + utils.set_random_seed(args.seed) + + args.outdir = experiments.prepare_output_dir(args, args.outdir) + print("Output files are saved in {}".format(args.outdir)) + + def make_env(test): + env = gym.make(args.env) + # Unwrap TimeLimit wrapper + assert isinstance(env, gym.wrappers.TimeLimit) + env = env.env + # Use different random seeds for train and test envs + env_seed = 2 ** 32 - 1 - args.seed if test else args.seed + env.seed(env_seed) + # Cast observations to float32 because our model uses float32 + if args.render and not test: + env = pfrl.wrappers.Render(env) + env = ComputeSuccessRate(env) + return env + + env = make_env(test=False) + timestep_limit = env.spec.max_episode_steps + obs_space = env.observation_space + action_space = env.action_space + print("Observation space:", obs_space) + print("Action space:", action_space) + + assert isinstance(obs_space, gym.spaces.Dict) + obs_size = obs_space["observation"].low.size + obs_space["desired_goal"].low.size + action_size = action_space.low.size + + def reward_fn(dg, ag): + return env.compute_reward(ag, dg, None) + + q_func = nn.Sequential( + ConcatObsAndAction(), + nn.Linear(obs_size + action_size, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 1), + ) + policy = nn.Sequential( + nn.Linear(obs_size, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, 256), + nn.ReLU(), + nn.Linear(256, action_size), + BoundByTanh(low=action_space.low, high=action_space.high), + DeterministicHead(), + ) + + def init_xavier_uniform(layer): + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform_(layer.weight) + nn.init.zeros_(layer.bias) + + with torch.no_grad(): + q_func.apply(init_xavier_uniform) + policy.apply(init_xavier_uniform) + + opt_a = torch.optim.Adam(policy.parameters()) + opt_c = torch.optim.Adam(q_func.parameters()) + + if args.replay_strategy == "future": + replay_strategy = replay_buffers.hindsight.ReplayFutureGoal() + else: + replay_strategy = replay_buffers.hindsight.ReplayFinalGoal() + rbuf = replay_buffers.hindsight.HindsightReplayBuffer( + reward_fn=reward_fn, + replay_strategy=replay_strategy, + capacity=10 ** 6, + ) + + explorer = EpsilonGreedyWithGaussianNoise( + epsilon=0.3, + random_action_func=lambda: env.action_space.sample(), + noise_scale=0.2, + ) + + # Normalize observations based on their empirical mean and variance + obs_normalizer = pfrl.nn.EmpiricalNormalization(obs_size, clip_threshold=5) + + def phi(observation): + # Feature extractor + obs = np.asarray(observation["observation"], dtype=np.float32) + 
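+        # The actor and critic are goal-conditioned: the raw observation and the
+        # desired goal are concatenated below, and values are clipped to a fixed range.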
dg = np.asarray(observation["desired_goal"], dtype=np.float32) + return np.concatenate((obs, dg)).clip(-200, 200) + + # 1 epoch = 10 episodes = 500 steps + gamma = 1.0 - 1.0 / timestep_limit + agent = pfrl.agents.DDPG( + policy, + q_func, + opt_a, + opt_c, + rbuf, + phi=phi, + gamma=gamma, + explorer=explorer, + replay_start_size=256, + target_update_method="soft", + target_update_interval=50, + update_interval=50, + soft_update_tau=5e-2, + n_times_update=40, + gpu=args.gpu, + minibatch_size=256, + clip_return_range=(-1.0 / (1.0 - gamma), 0.0), + action_l2_penalty_coef=1.0, + obs_normalizer=obs_normalizer, + ) + + if args.load: + agent.load(args.load) + + eval_env = make_env(test=True) + if args.demo: + eval_stats = experiments.eval_performance( + env=eval_env, + agent=agent, + n_steps=args.eval_n_steps, + n_episodes=None, + max_episode_len=timestep_limit, + ) + print( + "n_episodes: {} mean: {} median: {} stdev {}".format( + eval_stats["episodes"], + eval_stats["mean"], + eval_stats["median"], + eval_stats["stdev"], + ) + ) + else: + experiments.train_agent_with_evaluation( + agent=agent, + env=env, + steps=args.steps, + eval_n_steps=None, + eval_n_episodes=args.eval_n_episodes, + eval_interval=args.eval_interval, + outdir=args.outdir, + save_best_so_far_agent=True, + eval_env=eval_env, + train_max_episode_len=timestep_limit, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/her/train_dqn_bit_flip.py b/examples/her/train_dqn_bit_flip.py new file mode 100644 index 000000000..e640986a0 --- /dev/null +++ b/examples/her/train_dqn_bit_flip.py @@ -0,0 +1,252 @@ +import argparse + +import gym +import gym.spaces as spaces +import numpy as np +import torch +import torch.nn as nn + +from pfrl import agents, experiments, explorers, replay_buffers, utils +from pfrl.initializers import init_chainer_default +from pfrl.q_functions import DiscreteActionValueHead + + +def reward_fn(dg, ag): + return -1.0 if (ag != dg).any() else 0.0 + + +class BitFlip(gym.GoalEnv): + """BitFlip environment from https://arxiv.org/pdf/1707.01495.pdf + + Args: + n: State space is {0,1}^n + """ + + def __init__(self, n): + self.n = n + self.action_space = spaces.Discrete(n) + self.observation_space = spaces.Dict( + dict( + desired_goal=spaces.MultiBinary(n), + achieved_goal=spaces.MultiBinary(n), + observation=spaces.MultiBinary(n), + ) + ) + self.clear_statistics() + + def compute_reward(self, achieved_goal, desired_goal, info): + return reward_fn(desired_goal, achieved_goal) + + def _check_done(self): + success = ( + self.observation["desired_goal"] == self.observation["achieved_goal"] + ).all() + return (self.steps >= self.n) or success, success + + def step(self, action): + # Compute action outcome + bit_new = int(not self.observation["observation"][action]) + new_obs = self.observation["observation"].copy() + new_obs[action] = bit_new + # Set new observation + dg = self.observation["desired_goal"] + self.observation = { + "desired_goal": dg.copy(), + "achieved_goal": new_obs, + "observation": new_obs, + } + + reward = self.compute_reward( + self.observation["achieved_goal"], self.observation["desired_goal"], {} + ) + self.steps += 1 + done, success = self._check_done() + assert success == (reward == 0) + if done: + result = 1 if success else 0 + self.results.append(result) + return self.observation, reward, done, {} + + def reset(self): + sample_obs = self.observation_space.sample() + state, goal = sample_obs["observation"], sample_obs["desired_goal"] + while (state == goal).all(): + sample_obs = 
self.observation_space.sample() + state, goal = sample_obs["observation"], sample_obs["desired_goal"] + self.observation = dict() + self.observation["desired_goal"] = goal + self.observation["achieved_goal"] = state + self.observation["observation"] = state + self.steps = 0 + return self.observation + + def get_statistics(self): + failures = self.results.count(0) + successes = self.results.count(1) + assert len(self.results) == failures + successes + if not self.results: + return [("success_rate", None)] + success_rate = successes / float(len(self.results)) + return [("success_rate", success_rate)] + + def clear_statistics(self): + self.results = [] + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--outdir", + type=str, + default="results", + help=( + "Directory path to save output files." + " If it does not exist, it will be created." + ), + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 31)") + parser.add_argument( + "--gpu", type=int, default=0, help="GPU to use, set to -1 if no GPU." + ) + parser.add_argument("--demo", action="store_true", default=False) + parser.add_argument("--load", type=str, default=None) + parser.add_argument( + "--log-level", + type=int, + default=20, + help="Logging level. 10:DEBUG, 20:INFO etc.", + ) + parser.add_argument( + "--steps", + type=int, + default=5 * 10 ** 6, + help="Total number of timesteps to train the agent.", + ) + parser.add_argument( + "--replay-start-size", + type=int, + default=5 * 10 ** 2, + help="Minimum replay buffer size before performing gradient updates.", + ) + parser.add_argument( + "--num-bits", + type=int, + default=10, + help="Number of bits for BitFlipping environment.", + ) + parser.add_argument( + "--no-hindsight", + action="store_true", + default=False, + help="Do not use Hindsight Replay.", + ) + parser.add_argument("--eval-n-episodes", type=int, default=100) + parser.add_argument("--eval-interval", type=int, default=250000) + parser.add_argument("--n-best-episodes", type=int, default=100) + args = parser.parse_args() + + import logging + + logging.basicConfig(level=args.log_level) + + # Set a random seed used in PFRL. + utils.set_random_seed(args.seed) + + # Set different random seeds for train and test envs. 
+ train_seed = args.seed + test_seed = 2 ** 31 - 1 - args.seed + + args.outdir = experiments.prepare_output_dir(args, args.outdir) + print("Output files are saved in {}".format(args.outdir)) + + def make_env(test): + # Use different random seeds for train and test envs + env_seed = test_seed if test else train_seed + env = BitFlip(args.num_bits) + env.seed(int(env_seed)) + return env + + env = make_env(test=False) + eval_env = make_env(test=True) + + n_actions = env.action_space.n + q_func = nn.Sequential( + init_chainer_default(nn.Linear(args.num_bits * 2, 256)), + nn.ReLU(), + init_chainer_default(nn.Linear(256, n_actions)), + DiscreteActionValueHead(), + ) + + opt = torch.optim.Adam(q_func.parameters(), eps=1e-4) + + if args.no_hindsight: + rbuf = replay_buffers.ReplayBuffer(10 ** 6) + else: + rbuf = replay_buffers.hindsight.HindsightReplayBuffer( + reward_fn=reward_fn, + replay_strategy=replay_buffers.hindsight.ReplayFutureGoal(), + capacity=10 ** 6, + ) + + decay_steps = (args.num_bits + 5) * 10 ** 3 + end_epsilon = min(0.1, 0.5 / args.num_bits) + explorer = explorers.LinearDecayEpsilonGreedy( + start_epsilon=0.5, + end_epsilon=end_epsilon, + decay_steps=decay_steps, + random_action_func=lambda: np.random.randint(n_actions), + ) + + def phi(observation): + # Feature extractor + obs = np.asarray(observation["observation"], dtype=np.float32) + dg = np.asarray(observation["desired_goal"], dtype=np.float32) + return np.concatenate((obs, dg)) + + Agent = agents.DoubleDQN + agent = Agent( + q_func, + opt, + rbuf, + gpu=args.gpu, + gamma=0.99, + explorer=explorer, + replay_start_size=args.replay_start_size, + target_update_interval=10 ** 3, + clip_delta=True, + update_interval=4, + batch_accumulator="sum", + phi=phi, + ) + + if args.load: + agent.load(args.load) + + if args.demo: + eval_stats = experiments.eval_performance( + env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_episodes + ) + print( + "n_episodes: {} mean: {} median: {} stdev {}".format( + eval_stats["episodes"], + eval_stats["mean"], + eval_stats["median"], + eval_stats["stdev"], + ) + ) + else: + experiments.train_agent_with_evaluation( + agent=agent, + env=env, + steps=args.steps, + eval_n_steps=None, + eval_n_episodes=args.eval_n_episodes, + eval_interval=args.eval_interval, + outdir=args.outdir, + save_best_so_far_agent=True, + eval_env=eval_env, + ) + + +if __name__ == "__main__": + main() diff --git a/examples_tests/her/test_dqn_bit_flip.sh b/examples_tests/her/test_dqn_bit_flip.sh new file mode 100755 index 000000000..d89208e63 --- /dev/null +++ b/examples_tests/her/test_dqn_bit_flip.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +set -Ceu + +outdir=$(mktemp -d) + +gpu="$1" + +# her/dqn_bit_flip +python examples/her/train_dqn_bit_flip.py --gpu $gpu --steps 100 --outdir $outdir/her/bit_flip +model=$(find $outdir/her/bit_flip -name "*_finish") +python examples/her/train_dqn_bit_flip.py --demo --load $model --eval-n-episodes 1 --outdir $outdir/temp --gpu $gpu \ No newline at end of file diff --git a/pfrl/agents/ddpg.py b/pfrl/agents/ddpg.py index 9d2d15589..e007f8489 100644 --- a/pfrl/agents/ddpg.py +++ b/pfrl/agents/ddpg.py @@ -80,13 +80,19 @@ def __init__( logger=getLogger(__name__), batch_states=batch_states, burnin_action_func=None, + clip_return_range=None, + action_l2_penalty_coef=None, + obs_normalizer=None, ): self.model = nn.ModuleList([policy, q_func]) + self.obs_normalizer = obs_normalizer if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) 
self.model.to(self.device) + if self.obs_normalizer is not None: + self.obs_normalizer.to(self.device) else: self.device = torch.device("cpu") @@ -119,6 +125,8 @@ def __init__( ) self.batch_states = batch_states self.burnin_action_func = burnin_action_func + self.clip_return_range = clip_return_range + self.action_l2_penalty_coef = action_l2_penalty_coef self.t = 0 self.last_state = None @@ -163,6 +171,8 @@ def compute_critic_loss(self, batch): target_q = batch_rewards + self.gamma * ( 1.0 - batch_terminal ) * next_q.reshape((batchsize,)) + if self.clip_return_range is not None: + target_q = target_q.clamp(*self.clip_return_range) predict_q = self.q_function((batch_state, batch_actions)).reshape((batchsize,)) @@ -181,6 +191,9 @@ def compute_actor_loss(self, batch): q = self.q_function((batch_state, onpolicy_actions)) loss = -q.mean() + if self.action_l2_penalty_coef is not None: + loss += self.action_l2_penalty_coef * (onpolicy_actions ** 2).mean() + # Update stats self.q_record.extend(q.detach().cpu().numpy()) self.actor_loss_record.append(float(loss.detach().cpu().numpy())) @@ -192,6 +205,10 @@ def update(self, experiences, errors_out=None): batch = batch_experiences(experiences, self.device, self.phi, self.gamma) + if self.obs_normalizer: + batch["state"] = self.obs_normalizer(batch["state"], update=False) + batch["next_state"] = self.obs_normalizer(batch["next_state"], update=False) + self.critic_optimizer.zero_grad() self.compute_critic_loss(batch).backward() self.critic_optimizer.step() @@ -258,6 +275,8 @@ def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset): def _batch_select_greedy_actions(self, batch_obs): with torch.no_grad(), evaluating(self.policy): batch_xs = self.batch_states(batch_obs, self.device, self.phi) + if self.obs_normalizer: + batch_xs = self.obs_normalizer(batch_xs, update=False) batch_action = self.policy(batch_xs).sample() return batch_action.cpu().numpy() @@ -300,6 +319,12 @@ def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset) is_state_terminal=batch_done[i], env_id=i, ) + if self.obs_normalizer is not None: + self.obs_normalizer.experience( + self.batch_states( + [self.batch_last_obs[i]], self.device, self.phi + ) + ) if batch_reset[i] or batch_done[i]: self.batch_last_obs[i] = None self.batch_last_action[i] = None diff --git a/pfrl/replay_buffers/__init__.py b/pfrl/replay_buffers/__init__.py index 1c5b0ef2b..d33f1bfba 100644 --- a/pfrl/replay_buffers/__init__.py +++ b/pfrl/replay_buffers/__init__.py @@ -1,4 +1,8 @@ from pfrl.replay_buffers.episodic import EpisodicReplayBuffer # NOQA +from pfrl.replay_buffers.hindsight import HindsightReplayBuffer # NOQA +from pfrl.replay_buffers.hindsight import HindsightReplayStrategy # NOQA +from pfrl.replay_buffers.hindsight import ReplayFinalGoal # NOQA +from pfrl.replay_buffers.hindsight import ReplayFutureGoal # NOQA from pfrl.replay_buffers.persistent import PersistentEpisodicReplayBuffer # NOQA from pfrl.replay_buffers.persistent import PersistentReplayBuffer # NOQA from pfrl.replay_buffers.prioritized import PrioritizedReplayBuffer # NOQA diff --git a/pfrl/replay_buffers/hindsight.py b/pfrl/replay_buffers/hindsight.py new file mode 100644 index 000000000..284e96985 --- /dev/null +++ b/pfrl/replay_buffers/hindsight.py @@ -0,0 +1,148 @@ +import copy + +import numpy as np + +from pfrl.replay_buffer import random_subseq +from pfrl.replay_buffers.episodic import EpisodicReplayBuffer + + +def relabel_transition_goal(transition, goal_transition, reward_fn, swap_keys_list): + 
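+    # For each (desired_key, achieved_key) pair, the transition's desired-goal
+    # entries are overwritten with the goal transition's achieved values, and the
+    # stored reward is recomputed for the substituted goal via reward_fn.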
# Relabel/replace the desired goal for the transition with new_goal + for desired_obs_key, achieved_obs_key in swap_keys_list: + replacement = goal_transition["next_state"][achieved_obs_key] + transition["state"][desired_obs_key] = replacement + transition["next_state"][desired_obs_key] = replacement + new_goal = goal_transition["next_state"]["achieved_goal"] + achieved_goal = transition["next_state"]["achieved_goal"] + transition["reward"] = reward_fn(new_goal, achieved_goal) + return transition + + +class HindsightReplayStrategy: + """ReplayStrategy for Hindsight experience replay.""" + + def apply(self, episodes, reward_fn, swap_keys_list): + return episodes + + +class ReplayFinalGoal(HindsightReplayStrategy): + """Replay final goal.""" + + def apply(self, episodes, reward_fn, swap_keys_list): + batch_size = len(episodes) + episode_lens = np.array([len(episode) for episode in episodes]) + + # Randomly select time-steps from each episode + ts = [np.random.randint(ep_len) for ep_len in episode_lens] + ts = np.array(ts) + + # Select subset for hindsight goal replacement. + apply_hers = np.random.uniform(size=batch_size) < 0.5 + + batch = [] + for episode, apply_her, t in zip(episodes, apply_hers, ts): + transition = episode[t] + if apply_her: + final_transition = episode[-1] + transition = copy.deepcopy(transition) + transition = relabel_transition_goal( + transition, final_transition, reward_fn, swap_keys_list + ) + batch.append([transition]) + return batch + + +class ReplayFutureGoal(HindsightReplayStrategy): + """Replay random future goal. + + Args: + future_k (int): number of future goals to sample per true sample + """ + + def __init__(self, future_k=4): + self.future_prob = 1.0 - 1.0 / (float(future_k) + 1) + + def apply(self, episodes, reward_fn, swap_keys_list): + """Sample with the future strategy""" + batch_size = len(episodes) + episode_lens = np.array([len(episode) for episode in episodes]) + + # Randomly select time-steps from each episode + ts = [np.random.randint(ep_len) for ep_len in episode_lens] + ts = np.array(ts) + + # Select subset for hindsight goal replacement. future_k controls ratio + apply_hers = np.random.uniform(size=batch_size) < self.future_prob + + # Randomly select offsets for future goals + future_offset = np.random.uniform(size=batch_size) * (episode_lens - ts) + future_offset = future_offset.astype(int) + future_ts = ts + future_offset + batch = [] + for episode, apply_her, t, future_t in zip(episodes, apply_hers, ts, future_ts): + transition = episode[t] + if apply_her: + future_transition = episode[future_t] + transition = copy.deepcopy(transition) + transition = relabel_transition_goal( + transition, future_transition, reward_fn, swap_keys_list + ) + batch.append([transition]) + return batch + + +class HindsightReplayBuffer(EpisodicReplayBuffer): + """Hindsight Replay Buffer + + https://arxiv.org/abs/1707.01495 + We currently do not support N-step transitions for the + Hindsight Buffer. + Args: + reward_fn(fn): reward fn with input: (achieved_goal, desired_goal) + replay_strategy: instance of HindsightReplayStrategy() + capacity (int): Capacity of the replay buffer + swap_list (list): a list of tuples of keys to swap in the + observation. E.g. 
[(("desired_x", "achieved_x"))] This is used + to replace a transition's "desired_x" with a goal transition's + "achieved_x" + """ + + def __init__( + self, + reward_fn, + replay_strategy, + capacity=None, + swap_list=[("desired_goal", "achieved_goal")], + ): + + assert replay_strategy is not None + self.reward_fn = reward_fn + self.replay_strategy = replay_strategy + self.swap_keys_list = swap_list + assert ("desired_goal", "achieved_goal") in self.swap_keys_list + + super(HindsightReplayBuffer, self).__init__(capacity) + # probability of sampling a future goal instead of a true goal + + def sample(self, n): + # Sample n transitions from the hindsight replay buffer + assert len(self.memory) >= n + # Select n episodes + episodes = self.sample_episodes(n) + batch = self.replay_strategy.apply( + episodes, self.reward_fn, self.swap_keys_list + ) + return batch + + def sample_episodes(self, n_episodes, max_len=None): + episodes = self.sample_with_replacement(n_episodes) + if max_len is not None: + return [random_subseq(ep, max_len) for ep in episodes] + else: + return episodes + + def sample_with_replacement(self, k): + return [ + self.episodic_memory[i] + for i in np.random.randint(0, len(self.episodic_memory), k) + ] diff --git a/tests/replay_buffers_test/test_replay_buffer.py b/tests/replay_buffers_test/test_replay_buffer.py index bf2b2b037..0eceb11e7 100644 --- a/tests/replay_buffers_test/test_replay_buffer.py +++ b/tests/replay_buffers_test/test_replay_buffer.py @@ -317,6 +317,200 @@ def test_save_and_load(self): assert rbuf.n_episodes == 2 +@pytest.mark.parametrize("capacity", [100, None]) +@pytest.mark.parametrize("num_steps", [1, 3]) +class TestHindsightReplayBuffer: + @pytest.fixture(autouse=True) + def setUp(self, capacity, num_steps): + self.capacity = capacity + self.num_steps = num_steps + + def test_append_and_sample(self): + capacity = self.capacity + num_steps = self.num_steps + rbuf = replay_buffers.ReplayBuffer(capacity, num_steps) + + assert len(rbuf) == 0 + + # Add one and sample one + correct_item = collections.deque([], maxlen=num_steps) + for _ in range(num_steps): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item.append(trans1) + rbuf.append(**trans1) + assert len(rbuf) == 1 + s1 = rbuf.sample(1) + assert len(s1) == 1 + assert s1[0] == list(correct_item) + + # Add two and sample two, which must be unique + correct_item2 = copy.deepcopy(correct_item) + trans2 = dict( + state=1, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item2.append(trans2) + rbuf.append(**trans2) + assert len(rbuf) == 2 + s2 = rbuf.sample(2) + assert len(s2) == 2 + if s2[0][num_steps - 1]["state"] == 0: + assert s2[0] == list(correct_item) + assert s2[1] == list(correct_item2) + else: + assert s2[1] == list(correct_item) + assert s2[0] == list(correct_item2) + + def test_append_and_terminate(self): + capacity = self.capacity + num_steps = self.num_steps + rbuf = replay_buffers.ReplayBuffer(capacity, num_steps) + + assert len(rbuf) == 0 + + # Add one and sample one + for _ in range(num_steps): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + rbuf.append(**trans1) + assert len(rbuf) == 1 + s1 = rbuf.sample(1) + assert len(s1) == 1 + + # Add two and sample two, which must be unique + trans2 = dict( + state=1, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=True, + ) + 
rbuf.append(**trans2) + assert len(rbuf) == self.num_steps + 1 + s2 = rbuf.sample(self.num_steps + 1) + assert len(s2) == self.num_steps + 1 + if self.num_steps == 1: + if s2[0][0]["state"] == 0: + assert s2[1][0]["state"] == 1 + else: + assert s2[1][0]["state"] == 0 + else: + for item in s2: + # e.g. if states are 0,0,0,1 then buffer looks like: + # [[0,0,0], [0, 0, 1], [0, 1], [1]] + if len(item) < self.num_steps: + assert item[len(item) - 1]["state"] == 1 + for i in range(len(item) - 1): + assert item[i]["state"] == 0 + else: + for i in range(len(item) - 1): + assert item[i]["state"] == 0 + + def test_stop_current_episode(self): + capacity = self.capacity + num_steps = self.num_steps + rbuf = replay_buffers.ReplayBuffer(capacity, num_steps) + + assert len(rbuf) == 0 + + # Add one and sample one + for _ in range(num_steps - 1): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + rbuf.append(**trans1) + # we haven't experienced n transitions yet + assert len(rbuf) == 0 + # episode ends + rbuf.stop_current_episode() + # episode ends, so we should add n-1 transitions + assert len(rbuf) == self.num_steps - 1 + + def test_save_and_load(self): + capacity = self.capacity + num_steps = self.num_steps + + tempdir = tempfile.mkdtemp() + + rbuf = replay_buffers.ReplayBuffer(capacity, num_steps) + + correct_item = collections.deque([], maxlen=num_steps) + # Add two transitions + for _ in range(num_steps): + trans1 = dict( + state=0, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item.append(trans1) + rbuf.append(**trans1) + correct_item2 = copy.deepcopy(correct_item) + trans2 = dict( + state=1, + action=1, + reward=2, + next_state=3, + next_action=4, + is_state_terminal=False, + ) + correct_item2.append(trans2) + rbuf.append(**trans2) + + # Now it has two transitions + assert len(rbuf) == 2 + + # Save + filename = os.path.join(tempdir, "rbuf.pkl") + rbuf.save(filename) + + # Initialize rbuf + rbuf = replay_buffers.ReplayBuffer(capacity) + + # Of course it has no transition yet + assert len(rbuf) == 0 + + # Load the previously saved buffer + rbuf.load(filename) + + # Now it has two transitions again + assert len(rbuf) == 2 + + # And sampled transitions are exactly what I added! + s2 = rbuf.sample(2) + if s2[0][num_steps - 1]["state"] == 0: + assert s2[0] == list(correct_item) + assert s2[1] == list(correct_item2) + else: + assert s2[0] == list(correct_item2) + assert s2[1] == list(correct_item) + + @pytest.mark.parametrize("capacity", [100, None]) @pytest.mark.parametrize("normalize_by_max", ["batch", "memory"]) class TestPrioritizedReplayBuffer:
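A minimal usage sketch of the new `HindsightReplayBuffer` follows, based only on the classes and signatures added in `pfrl/replay_buffers/hindsight.py` and re-exported from `pfrl.replay_buffers` in this diff. The reward function mirrors the `(desired_goal, achieved_goal)` convention of `train_dqn_bit_flip.py`, and the observations stored in the buffer are assumed to be `gym.GoalEnv`-style dicts with `"observation"`, `"achieved_goal"`, and `"desired_goal"` keys, which is what the default `swap_list` and the relabelling helper expect.

```
from pfrl import replay_buffers

# Sparse goal-reaching reward, with the reward_fn(desired_goal, achieved_goal)
# signature expected by HindsightReplayBuffer (same as the BitFlip example).
def reward_fn(dg, ag):
    return -1.0 if (ag != dg).any() else 0.0

rbuf = replay_buffers.HindsightReplayBuffer(
    reward_fn=reward_fn,
    # ReplayFutureGoal(future_k=4) relabels ~80% of sampled transitions with a
    # goal achieved later in the same episode; ReplayFinalGoal instead relabels
    # ~50% of them with the episode's final achieved goal.
    replay_strategy=replay_buffers.ReplayFutureGoal(future_k=4),
    capacity=10 ** 6,
)

# Transitions are appended as with EpisodicReplayBuffer (dict observations, with
# stop_current_episode() at episode boundaries); sample(n) then draws n episodes
# with replacement and returns n goal-relabelled transitions, each wrapped in a
# length-1 list.
```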