2 changes: 2 additions & 0 deletions experiment_code/hackrl/config.yaml
@@ -177,3 +177,5 @@ eval_checkpoint_every: 50_000_000
eval_rollouts: 1024
eval_batch_size: 256
skip_first_eval: False

omitted_dlvls: 0
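# frames with Dlvl <= omitted_dlvls are masked out of the behavioural-cloning losses (0 disables masking)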
80 changes: 74 additions & 6 deletions experiment_code/hackrl/experiment.py
@@ -46,6 +46,7 @@

class TtyrecEnvPool:
def __init__(self, flags, dataset_name, dataset_scores, **dataset_kwargs):
self.flags = flags
@BartekCupial (Owner) commented on Jul 4, 2023:
Maybe it's better to just pass this omitted_dlvls argument instead of the whole flags object?
self.omitted_dlvls = omitted_dlvls
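A minimal sketch of that suggestion (illustrative only, not part of this diff; the keyword default and the call-site change are assumptions):

    def __init__(self, flags, dataset_name, dataset_scores, omitted_dlvls=0, **dataset_kwargs):
        self.omitted_dlvls = omitted_dlvls  # store only the value the pool needs
        ...

with the batching code below then reading self.omitted_dlvls instead of self.flags.omitted_dlvls.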

self.idx = 0
self.env_pool_size = flags.ttyrec_envpool_size
self.dataset = dataset.TtyrecDataset(dataset_name, **dataset_kwargs)
@@ -139,6 +140,58 @@ def _iter():
)
list(self.threadpool.map(convert, range(self.ttyrec_batch_size)))

# Take the bottom two tty rows (NetHack's status lines, which contain "Dlvl:<n>"), flatten, and convert to a numpy array
tty_chars = (
mb_tensors["tty_chars"][..., -2:, :].flatten(start_dim=2).detach().cpu().numpy()
)

# Convert numeric values to characters
characters = np.apply_along_axis(
lambda x: "".join(map(chr, x)), axis=2, arr=tty_chars
)

# Split characters by space
words = np.char.split(characters, " ")

# Extract Dlvl values from words
dlvl = np.array(
[
[
[w.split(":")[-1] for w in string if w.startswith("Dlvl:")]
for string in sublist
]
for sublist in words
]
)
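# Frames whose status line has no "Dlvl:" token (e.g. the end-game planes show a
# location name instead) yield an empty list here; they are handled below.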

def _filter_2d_dlvl_array(frame):
def filter_element(x):
if isinstance(x, list) and len(x) > 0:
return int(x[0])
else:
return None

# Vectorize the filtering function
filter_vectorized = np.vectorize(filter_element, otypes=[object])

# Apply the filtering function to the frame
filtered_frame = filter_vectorized(frame).astype(float)

vals, counts = np.unique(filtered_frame[~np.isnan(filtered_frame)], return_counts=True)
replace_value = vals[np.argmax(counts)] if len(vals) > 0 else 0

# Missing Dlvl entries became NaN after astype(float); fill them with the
# most frequent Dlvl observed in this minibatch
filtered_frame[np.isnan(filtered_frame)] = replace_value

return filtered_frame

dlvl = _filter_2d_dlvl_array(dlvl)

# Keep only frames recorded deeper than omitted_dlvls; with the option at 0 every frame is kept
mask = (
torch.from_numpy(dlvl > self.flags.omitted_dlvls)
if self.flags.omitted_dlvls > 0
else torch.ones_like(mb_tensors["timestamps"]).bool()
)

final_mb = {
"tty_chars": mb_tensors["tty_chars"],
"tty_colors": mb_tensors["tty_colors"],
@@ -147,7 +200,7 @@ def _iter():
"done": mb_tensors["done"].bool(),
"timesteps": mb_tensors["timestamps"].float(),
# "max_scores": max_scores[mb["gameids"].flatten()].reshape(mb["gameids"].shape).float(),
"mask": torch.ones_like(mb_tensors["timestamps"]).bool()
"mask": mask,
}

if "keypresses" in mb_tensors:
@@ -557,14 +610,23 @@ def compute_entropy_loss(logits, stats=None):
return -torch.mean(entropy_per_timestep)


def compute_kickstarting_loss(student_logits, expert_logits):
def compute_kickstarting_loss(student_logits, expert_logits, mask: torch.Tensor = None):
T, B, *_ = student_logits.shape
return torch.nn.functional.kl_div(
if mask is None:
return torch.nn.functional.kl_div(
F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
log_target=True,
reduction="batchmean",
)
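# Masked variant: keep the per-element KL, zero out masked timesteps by broadcasting
# the (T*B,) mask over the action dimension, then normalise by T*B as batchmean would.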
loss = torch.nn.functional.kl_div(
F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
log_target=True,
reduction="batchmean",
reduction="none",
)
loss = loss.T * mask
return loss.sum() / B / T


def compute_policy_gradient_loss(
@@ -662,9 +724,12 @@ def compute_gradients(data, sleep_data, learner_state, stats):
logits[:-1], expert[:-1]
)
else:
# TODO: Why do we take up to -1 index here and above?
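# Per-sample cross-entropy so frames from omitted dungeon levels (mask == 0)
# contribute zero to the behavioural-cloning loss before the mean is taken.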
supervised_loss = (
FLAGS.supervised_loss * F.cross_entropy(logits[:-1], true_a[:-1]).mean()
)
FLAGS.supervised_loss
* F.cross_entropy(logits[:-1], true_a[:-1], reduction="none")
* torch.flatten(ttyrec_data["mask"], 0, 1)[:-1].int()
).mean()
FLAGS.supervised_loss *= FLAGS.supervised_decay
stats["supervised_loss"] += supervised_loss.item()
stats["supervised_coeff"] += FLAGS.supervised_loss
@@ -831,6 +896,8 @@ def compute_gradients(data, sleep_data, learner_state, stats):
stats["inverse_loss"] += inverse_loss.item()

if FLAGS.use_kickstarting:
# TODO phase 2: add a regularization-only mask once we reach a particular dungeon level

kickstarting_loss = FLAGS.kickstarting_loss * compute_kickstarting_loss(
learner_outputs["policy_logits"],
actor_outputs["kick_policy_logits"],
@@ -855,6 +922,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
kickstarting_loss_bc = FLAGS.kickstarting_loss_bc * compute_kickstarting_loss(
ttyrec_predictions["policy_logits"],
ttyrec_predictions["kick_policy_logits"],
torch.flatten(ttyrec_data["mask"], 0, 1).int()
)
FLAGS.kickstarting_loss_bc *= FLAGS.kickstarting_decay_bc
total_loss += kickstarting_loss_bc
@@ -0,0 +1,56 @@
from random_word import RandomWords

from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations


name = globals()["script"][:-3]

# params for all exps
config = {
"exp_tags": [name],
"connect":"0.0.0.0:4431",
"exp_set": "2G",
"exp_point": "monk-AA-BC",
"num_actor_cpus": 20,
"total_steps": 2_000_000_000,
"actor_batch_size": 256,
"batch_size": 128,
"ttyrec_batch_size": 512,
"supervised_loss": 1,
"adam_learning_rate": 0.001,
"behavioural_clone": True,
"group": "monk-AA-BC",
"character": "mon-hum-neu-mal",
"omitted_dlvls": 0
}


# params different between exps
params_grid = [
{
"seed": [0, 1],
},
]

params_configurations = get_combinations(params_grid)

final_grid = []
for e, cfg in enumerate(params_configurations):
cfg = {key: [value] for key, value in cfg.items()}
r = RandomWords().get_random_word()
cfg["group"] = [f"{name}_{e}_{r}"]
final_grid.append(dict(cfg))


experiments_list = create_experiments_helper(
experiment_name=name,
project_name="nle",
with_neptune=False,
script="python3 mrunner_run.py",
python_path=".",
tags=[name],
exclude=["checkpoint"],
base_config=config,
params_grid=final_grid,
exclude_git_files=False,
)
@@ -0,0 +1,56 @@
from random_word import RandomWords

from mrunner.helpers.specification_helper import create_experiments_helper, get_combinations


name = globals()["script"][:-3]

# params for all exps
config = {
"exp_tags": [name],
"connect":"0.0.0.0:4431",
"exp_set": "2G",
"exp_point": "monk-AA-BC",
"num_actor_cpus": 20,
"total_steps": 2_000_000_000,
"actor_batch_size": 256,
"batch_size": 128,
"ttyrec_batch_size": 512,
"supervised_loss": 1,
"adam_learning_rate": 0.001,
"behavioural_clone": True,
"group": "monk-AA-BC",
"character": "mon-hum-neu-mal",
"omitted_dlvls":4
}


# params different between exps
params_grid = [
{
"seed": [0, 1, 2],
},
]

params_configurations = get_combinations(params_grid)

final_grid = []
for e, cfg in enumerate(params_configurations):
cfg = {key: [value] for key, value in cfg.items()}
r = RandomWords().get_random_word()
cfg["group"] = [f"{name}_{e}_{r}"]
final_grid.append(dict(cfg))


experiments_list = create_experiments_helper(
experiment_name=name,
project_name="nle",
with_neptune=False,
script="python3 mrunner_run.py",
python_path=".",
tags=[name],
exclude=["checkpoint"],
base_config=config,
params_grid=final_grid,
exclude_git_files=False,
)
@@ -0,0 +1,122 @@
from random_word import RandomWords

from mrunner.helpers.specification_helper import (
create_experiments_helper,
get_combinations,
)


name = globals()["script"][:-3]

# params for all exps
config = {
"connect": "0.0.0.0:4431",
"exp_set": "2G",
"exp_point": "monk-APPO-AA-KL",
"num_actor_cpus": 20,
"total_steps": 2_000_000_000,
"group": "monk-APPO-AA-KL",
"character": "mon-hum-neu-mal",
"use_checkpoint_actor": False,
"ttyrec_batch_size": 256,
"kickstarting_loss_bc": 0.5,
"use_kickstarting_bc": True,
}


# params different between exps
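# Each grid entry below pairs the skip<N> checkpoint with omitted_dlvls = N (N = 4..0).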
params_grid = [
{
"exp_tags": [f"{name}-4"],
"seed": list(range(5)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"kickstarting_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip4/checkpoint.tar"
],
"model_checkpoint_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip4/checkpoint.tar"
],
"omitted_dlvls": [4],
},
{
"exp_tags": [f"{name}-3"],
"seed": list(range(5)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"kickstarting_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip3/checkpoint.tar"
],
"model_checkpoint_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip3/checkpoint.tar"
],
"omitted_dlvls": [3],
},
{
"exp_tags": [f"{name}-2"],
"seed": list(range(5)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"kickstarting_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip2/checkpoint.tar"
],
"model_checkpoint_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip2/checkpoint.tar"
],
"omitted_dlvls": [2],
},
{
"exp_tags": [f"{name}-1"],
"seed": list(range(5)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"kickstarting_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip1/checkpoint.tar"
],
"model_checkpoint_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip1/checkpoint.tar"
],
"omitted_dlvls": [1],
},
{
"exp_tags": [f"{name}-0"],
"seed": list(range(5)),
# load from checkpoint
"unfreeze_actor_steps": [0],
"use_checkpoint_actor": [True],
"kickstarting_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip0/checkpoint.tar"
],
"model_checkpoint_path": [
"/net/pr2/projects/plgrid/plgg_pw_crl/mbortkiewicz/mrunner_scratch/checkpoints_nle/skip0/checkpoint.tar"
],
"omitted_dlvls": [0],
},
]

params_configurations = get_combinations(params_grid)

final_grid = []
for e, cfg in enumerate(params_configurations):
cfg = {key: [value] for key, value in cfg.items()}
r = RandomWords().get_random_word()
cfg["group"] = [f"{name}_{e}_{r}"]
final_grid.append(dict(cfg))


experiments_list = create_experiments_helper(
experiment_name=name,
project_name="nle",
with_neptune=False,
script="python3 mrunner_run.py",
python_path=".",
tags=[name],
exclude=["checkpoint"],
base_config=config,
params_grid=final_grid,
exclude_git_files=False,
)