38 changes: 28 additions & 10 deletions atariari/benchmark/episodes.py
@@ -6,7 +6,7 @@
 import time
 import os
 from .envs import make_vec_envs
-from .utils import download_run
+from .utils import download_run, appendabledict
 try:
     import wandb
 except:
@@ -61,8 +61,9 @@ def get_random_agent_rollouts(env_name, steps, seed=42, num_processes=1, num_fra
     episodes = list(chain.from_iterable(episodes))
     # Convert to 2d list from 3d list
     episode_labels = list(chain.from_iterable(episode_labels))
+    tensor_episodes = [torch.stack(episodes[i]) for i in range(len(episodes))]
     envs.close()
-    return episodes, episode_labels
+    return tensor_episodes, episode_labels
 
 
 def get_ppo_rollouts(env_name, steps, seed=42, num_processes=1,
@@ -110,12 +111,25 @@ def get_ppo_rollouts(env_name, steps, seed=42, num_processes=1,
     episode_labels = list(chain.from_iterable(episode_labels))
     mean_entropy = torch.stack(entropies).mean()
     mean_episode_reward = np.mean(episode_rewards)
+    tensor_episodes = [torch.stack(episodes[i]) for i in range(len(episodes))]
     try:
         wandb.log({'action_entropy': mean_entropy, 'mean_reward': mean_episode_reward})
    except:
         pass
 
-    return episodes, episode_labels
+    return tensor_episodes, episode_labels
+
+
+def tensorify_labels(eps_labels):
+    tensor_ep_labels = []
+    for ep_labels in eps_labels:
+        ad = appendabledict()
+        ad.append_updates(ep_labels)
+        for k in ad.keys():
+            ad[k] = torch.tensor(ad[k])
+        tensor_ep_labels.append(ad)
+
+    return tensor_ep_labels
 
 
 def get_episodes(env_name,
@@ -133,7 +147,7 @@ def get_episodes(env_name,
 
     if collect_mode == "random_agent":
         # List of episodes. Each episode is a list of 160x210 observations
-        episodes, episode_labels = get_random_agent_rollouts(env_name=env_name,
+        episodes, episodes_labels = get_random_agent_rollouts(env_name=env_name,
                                                              steps=steps,
                                                              seed=seed,
                                                              num_processes=num_processes,
@@ -143,7 +157,7 @@
     elif collect_mode == "pretrained_ppo":
         import wandb
         # List of episodes. Each episode is a list of 160x210 observations
-        episodes, episode_labels = get_ppo_rollouts(env_name=env_name,
+        episodes, episodes_labels = get_ppo_rollouts(env_name=env_name,
                                                     steps=steps,
                                                     seed=seed,
                                                     num_processes=num_processes,
@@ -156,10 +170,12 @@
     else:
         assert False, "Collect mode {} not recognized".format(collect_mode)
 
+
     ep_inds = [i for i in range(len(episodes)) if len(episodes[i]) > min_episode_length]
     episodes = [episodes[i] for i in ep_inds]
-    episode_labels = [episode_labels[i] for i in ep_inds]
-    episode_labels, entropy_dict = remove_low_entropy_labels(episode_labels, entropy_threshold=entropy_threshold)
+    episodes_labels = [episodes_labels[i] for i in ep_inds]
+    episodes_labels, entropy_dict = remove_low_entropy_labels(episodes_labels, entropy_threshold=entropy_threshold)
+
 
     try:
         wandb.log(entropy_dict)
@@ -182,16 +198,18 @@ def get_episodes(env_name,
             "Not enough episodes to split into train, val and test. You must specify more steps"
         tr_eps, val_eps, test_eps = episodes[:val_split_ind], episodes[val_split_ind:te_split_ind], episodes[
                                                                                                     te_split_ind:]
-        tr_labels, val_labels, test_labels = episode_labels[:val_split_ind], \
-            episode_labels[val_split_ind:te_split_ind], episode_labels[te_split_ind:]
+        tr_labels, val_labels, test_labels = episodes_labels[:val_split_ind], \
+            episodes_labels[val_split_ind:te_split_ind], episodes_labels[te_split_ind:]
         test_eps, test_labels = remove_duplicates(tr_eps, val_eps, test_eps, test_labels)
         test_ep_inds = [i for i in range(len(test_eps)) if len(test_eps[i]) > 1]
         test_eps = [test_eps[i] for i in test_ep_inds]
         test_labels = [test_labels[i] for i in test_ep_inds]
 
+        tr_labels, val_labels, test_labels = tensorify_labels(tr_labels), tensorify_labels(val_labels), tensorify_labels(test_labels)
         return tr_eps, val_eps, tr_labels, val_labels, test_eps, test_labels
 
     if train_mode == "dry_run":
-        return episodes, episode_labels
+        return episodes, episodes_labels
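With this change, both rollout functions return each episode as one stacked tensor instead of a list of frames, and `tensorify_labels` turns each episode's list of per-frame label dicts into a single `appendabledict` of 1-D tensors. A minimal sketch of the resulting shapes, assuming hypothetical 1x210x160 uint8 frames and a made-up `player_x` label key:

```python
import torch

# One episode of 100 hypothetical frames; real frames come from the Atari envs.
frames = [torch.zeros(1, 210, 160, dtype=torch.uint8) for _ in range(100)]
episode = torch.stack(frames)        # per-episode tensor, shape [100, 1, 210, 160]

# tensorify_labels collapses [{"player_x": 5}, {"player_x": 7}, ...] per episode
# into one dict of 1-D tensors indexed by label key:
ep_labels = [{"player_x": 5}, {"player_x": 7}]
tensor_labels = {k: torch.tensor([d[k] for d in ep_labels]) for k in ep_labels[0]}
print(episode.shape, tensor_labels["player_x"])  # torch.Size([100, 1, 210, 160]) tensor([5, 7])
```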
39 changes: 20 additions & 19 deletions atariari/benchmark/probe.py
@@ -31,7 +31,7 @@ def forward(self, x):
         return self.probe(feature_vec)
 
 
-class ProbeTrainer():
+class ProbeTrainer(object):
     def __init__(self,
                  encoder=None,
                  method_name="my_method",
@@ -61,27 +61,28 @@ def __init__(self,
         self.loss_fn = nn.CrossEntropyLoss()
 
         # bad convention, but these get set in "create_probes"
-        self.probes = self.early_stoppers = self.optimizers = self.schedulers = None
+        self.probes = self.early_stoppers = self.optimizers = self.schedulers = self.label_keys = None
 
-    def create_probes(self, sample_label):
+    def create_probes(self, label_keys):
+        self.label_keys = label_keys
         if self.fully_supervised:
             assert self.encoder != None, "for fully supervised you must provide an encoder!"
             self.probes = {k: FullySupervisedLinearProbe(encoder=self.encoder,
                                                          num_classes=self.num_classes).to(self.device) for k in
-                           sample_label.keys()}
+                           self.label_keys}
         else:
             self.probes = {k: LinearProbe(input_dim=self.feature_size,
-                                          num_classes=self.num_classes).to(self.device) for k in sample_label.keys()}
+                                          num_classes=self.num_classes).to(self.device) for k in self.label_keys}
 
         self.early_stoppers = {
             k: EarlyStopping(patience=self.patience, verbose=False, name=k + "_probe", save_dir=self.save_dir)
-            for k in sample_label.keys()}
+            for k in self.label_keys}
 
         self.optimizers = {k: torch.optim.Adam(list(self.probes[k].parameters()),
-                                               eps=1e-5, lr=self.lr) for k in sample_label.keys()}
+                                               eps=1e-5, lr=self.lr) for k in self.label_keys}
         self.schedulers = {
             k: torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizers[k], patience=5, factor=0.2, verbose=True,
-                                                          mode='max', min_lr=1e-5) for k in sample_label.keys()}
+                                                          mode='max', min_lr=1e-5) for k in self.label_keys}
 
     def generate_batch(self, episodes, episode_labels):
         total_steps = sum([len(e) for e in episodes])
@@ -101,7 +102,9 @@ def generate_batch(self, episodes, episode_labels):
                 # Get one sample from this episode
                 t = np.random.randint(len(episode))
                 xs.append(episode[t])
-                labels.append_update(episode_labels_batch[ep_ind][t])
+                label = episode_labels_batch[ep_ind].subslice(slice(t, t+1))
+                labels.extend_update(label)
+            labels = {k: torch.tensor(v).long() for k, v in labels.items()}
             yield torch.stack(xs).float().to(self.device) / 255., labels
 
     def probe(self, batch, k):
@@ -125,10 +128,9 @@ def probe(self, batch, k):
         return preds
 
     def do_one_epoch(self, episodes, label_dicts):
-        sample_label = label_dicts[0][0]
-        epoch_loss, accuracy = {k + "_loss": [] for k in sample_label.keys() if
+        epoch_loss, accuracy = {k + "_loss": [] for k in self.label_keys if
                                 not self.early_stoppers[k].early_stop}, \
-                               {k + "_acc": [] for k in sample_label.keys() if
+                               {k + "_acc": [] for k in self.label_keys if
                                 not self.early_stoppers[k].early_stop}
 
         data_generator = self.generate_batch(episodes, label_dicts)
@@ -139,7 +141,7 @@ def do_one_epoch(self, episodes, label_dicts):
                 optim = self.optimizers[k]
                 optim.zero_grad()
 
-                label = torch.tensor(label).long().to(self.device)
+                label = label.to(self.device)
                 preds = self.probe(x, k)
 
                 loss = self.loss_fn(preds, label)
@@ -160,10 +162,9 @@ def do_one_epoch(self, episodes, label_dicts):
         return epoch_loss, accuracy
 
     def do_test_epoch(self, episodes, label_dicts):
-        sample_label = label_dicts[0][0]
         accuracy_dict, f1_score_dict = {}, {}
-        pred_dict, all_label_dict = {k: [] for k in sample_label.keys()}, \
-                                    {k: [] for k in sample_label.keys()}
+        pred_dict, all_label_dict = {k: [] for k in self.label_keys}, \
+                                    {k: [] for k in self.label_keys}
 
         # collect all predictions first
         data_generator = self.generate_batch(episodes, label_dicts)
@@ -189,8 +190,8 @@ def do_test_epoch(self, episodes, label_dicts):
     def train(self, tr_eps, val_eps, tr_labels, val_labels):
         # if not self.encoder:
         #     assert len(tr_eps[0][0].squeeze().shape) == 2, "if input is a batch of vectors you must specify an encoder!"
-        sample_label = tr_labels[0][0]
-        self.create_probes(sample_label)
+        label_keys = tr_labels[0].keys()
+        self.create_probes(label_keys)
         e = 0
         all_probes_stopped = np.all([early_stopper.early_stop for early_stopper in self.early_stoppers.values()])
         while (not all_probes_stopped) and e < self.epochs:
@@ -199,7 +200,7 @@ def train(self, tr_eps, val_eps, tr_labels, val_labels):
 
             val_loss, val_accuracy = self.evaluate(val_eps, val_labels, epoch=e)
             # update all early stoppers
-            for k in sample_label.keys():
+            for k in self.label_keys:
                 if not self.early_stoppers[k].early_stop:
                     self.early_stoppers[k](val_accuracy["val_" + k + "_acc"], self.probes[k])
 
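`ProbeTrainer` now caches `self.label_keys` once in `create_probes` and keys every per-label structure (probes, optimizers, early stoppers, schedulers) off it, while `generate_batch` yields labels as ready-made `LongTensor`s. A minimal sketch of the per-key pattern, assuming two made-up label keys and stubbing the probe with a bare `nn.Linear`:

```python
import torch
import torch.nn as nn

label_keys = ["player_x", "player_y"]   # assumed keys; real ones come from the label dicts
feature_size, num_classes, batch_size = 256, 256, 8

probes = {k: nn.Linear(feature_size, num_classes) for k in label_keys}
optimizers = {k: torch.optim.Adam(probes[k].parameters(), eps=1e-5, lr=3e-4) for k in label_keys}
loss_fn = nn.CrossEntropyLoss()

# Each batch now pairs features with one LongTensor of class ids per key:
labels = {k: torch.randint(num_classes, (batch_size,)) for k in label_keys}
for k, label in labels.items():
    optimizers[k].zero_grad()
    preds = probes[k](torch.randn(batch_size, feature_size))
    loss_fn(preds, label).backward()
    optimizers[k].step()
```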
36 changes: 36 additions & 0 deletions atariari/benchmark/utils.py
@@ -76,6 +76,42 @@ def append_update(self, other_dict):
         for k, v in other_dict.items():
             self.__getitem__(k).append(v)
 
+    def append_updates(self, list_of_dicts):
+        """appends this dict's values with the values from every dict in list_of_dicts
+
+        Parameters
+        ----------
+        list_of_dicts : list of dict
+            A list of dictionaries whose values you want appended to this dictionary's values
+
+        Returns
+        -------
+        Nothing. The side effect is this dict's values change
+
+        """
+        for other_dict in list_of_dicts:
+            self.append_update(other_dict)
+
+    def extend_update(self, other_dict):
+        """extends this dict's values with the value sequences from other_dict
+
+        Parameters
+        ----------
+        other_dict : dict
+            A dictionary whose value sequences you want concatenated onto this dictionary's values
+
+        Returns
+        -------
+        Nothing. The side effect is this dict's values change
+
+        """
+        for k, v in other_dict.items():
+            self.__getitem__(k).extend(v)
+
+
 # Thanks Bjarten! (https://github.com/Bjarten/early-stopping-pytorch)
 class EarlyStopping(object):
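The two new `appendabledict` helpers differ only in granularity: `append_updates` folds a list of per-frame dicts into the accumulator one `append` at a time, while `extend_update` concatenates whole value sequences per key. A minimal sketch using `defaultdict(list)` as a stand-in for `appendabledict`:

```python
from collections import defaultdict

d = defaultdict(list)

# append_update-style: each key gains ONE element (a scalar label for a single frame).
for k, v in {"player_x": 5}.items():
    d[k].append(v)                  # d == {"player_x": [5]}

# extend_update-style: each key gains a SEQUENCE (e.g. a one-frame subslice of an episode).
for k, v in {"player_x": [7, 9]}.items():
    d[k].extend(v)                  # d == {"player_x": [5, 7, 9]}
```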