Skip to content

Commit ab5f8b4

Browse files
committed
bigger critic with LN to check if this helps with expressivity
1 parent 505ede7 commit ab5f8b4

File tree

6 files changed

+214
-10
lines changed

6 files changed

+214
-10
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from mrunner.helpers.specification_helper import create_experiments_helper
2+
3+
# The experiment is named after the launching script, minus its 3-character suffix.
name = globals()["script"][:-3]

# Base configuration shared by every experiment generated below.
config = {
    "env": "challenge",
    "exp_tags": [name],
    "exp_point": "monk-APPO-T",
    "train_for_env_steps": 500_000_000,
    "group": "monk-APPO-T",
    "character": "mon-hum-neu-mal",
    "num_workers": 16,
    "num_envs_per_worker": 16,
    "worker_num_splits": 2,
    "rollout": 32,
    "batch_size": 4096,  # 128 trajectories of rollout length 32: 128 * 32 = 4096
    "async_rl": True,
    "serial_mode": False,
    "wandb_user": "bartekcupial",
    "wandb_project": "sf2_nethack",
    "wandb_group": "gmum",
    "with_wandb": True,
    "use_pretrained_checkpoint": True,
    "model_path": "/net/pr2/projects/plgrid/plgggmum_crl/bcupial/sf_checkpoints/amzn-AA-BC_pretrained",
    "use_prev_action": True,
    "model": "ScaledNet",
    "use_resnet": True,
    "learning_rate": 0.0001,
    "rnn_size": 1738,
    "h_dim": 1738,
    "gamma": 1.0,
    "skip_train": 25_000_000,
    "lr_schedule": "linear_decay",
    "save_milestones_ith": 25_000_000,
}

params_grid = []
expected_batch_size = 4096

for rollout, target_batch_size in [(r, t) for r in [128] for t in [128]]:
    # Total samples wanted per update; the batch itself never exceeds expected_batch_size.
    samples_per_update = target_batch_size * rollout
    capped_samples = min(samples_per_update, expected_batch_size * 8)
    batch_size = min(expected_batch_size, capped_samples)
    batches_to_accumulate = max(1, samples_per_update // expected_batch_size)
    optim_step_every_ith = max(1, batches_to_accumulate // 8)

    # Each variant is (what to freeze, extra grid entries placed before "critic_mlp_layers").
    variants = [
        (
            {"actor_encoder": 0},
            {
                "actor_critic_share_weights": [False],
                "critic_add_layernorm": [True],
                "critic_replace_bn_with_ln": [True, False],
            },
        ),
        ({"actor_encoder": 0}, {"actor_critic_share_weights": [False]}),
        ({"encoder": 0}, {}),
    ]

    for freeze, extra in variants:
        params_grid.append(
            {
                "seed": list(range(3)),
                "learning_rate": [0.0001],
                "freeze": [freeze],
                "rollout": [rollout],
                # rollout * target_batch_size combinations: 32 * 512, 64 * 256, 128 * 128
                "batch_size": [batch_size],
                "num_batches_per_epoch": [min(8, batches_to_accumulate)],
                "optim_step_every_ith": [optim_step_every_ith],
                "target_batch_size": [target_batch_size],
                **extra,
                "critic_mlp_layers": [[512], [512, 512], [512, 512, 512]],
            }
        )


experiments_list = create_experiments_helper(
    experiment_name=name,
    project_name="sf2_nethack",
    with_neptune=False,
    script="python3 mrunner_run.py",
    python_path=".",
    tags=[name],
    base_config=config,
    params_grid=params_grid,
    mrunner_ignore=".mrunnerignore",
)

sample_factory/cfg/cfg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import multiprocessing
23
import os
34
from argparse import ArgumentParser
@@ -565,10 +566,15 @@ def add_model_args(p: ArgumentParser):
565566
p.add_argument(
566567
"--decoder_mlp_layers",
567568
default=[],
568-
type=int,
569-
nargs="*",
569+
type=ast.literal_eval,
570570
help="Optional decoder MLP layers after the policy core. If empty (default) decoder is identity function.",
571571
)
572+
p.add_argument(
573+
"--critic_mlp_layers",
574+
default=[],
575+
type=ast.literal_eval,
576+
help="Optional critic MLP layers after the policy core. If empty (default) critic is a linear function.",
577+
)
572578

573579
p.add_argument(
574580
"--nonlinearity", default="elu", choices=["elu", "relu", "tanh"], type=str, help="Type of nonlinearity to use."

sample_factory/model/actor_critic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150

151151
decoder_out_size: int = self.decoder.get_out_size()
152152

153-
self.critic_linear = nn.Linear(decoder_out_size, 1)
153+
self.critic = model_factory.make_model_critic_func(cfg, self.decoder.get_out_size())
154154
self.action_parameterization = self.get_action_parameterization(decoder_out_size)
155155

156156
self.apply(self.initialize_weights)
@@ -165,7 +165,7 @@ def forward_core(self, head_output: Tensor, rnn_states):
165165

166166
def forward_tail(self, core_output, values_only: bool, sample_actions: bool) -> TensorDict:
167167
decoder_output = self.decoder(core_output)
168-
values = self.critic_linear(decoder_output).squeeze()
168+
values = self.critic(decoder_output).squeeze()
169169

170170
result = TensorDict(values=values)
171171
if values_only:
@@ -212,7 +212,7 @@ def __init__(
212212
self.critic_decoder = model_factory.make_model_decoder_func(cfg, self.critic_core.get_out_size())
213213
self.decoders = [self.actor_decoder, self.critic_decoder]
214214

215-
self.critic_linear = nn.Linear(self.critic_decoder.get_out_size(), 1)
215+
self.critic = model_factory.make_model_critic_func(cfg, self.critic_decoder.get_out_size())
216216
self.action_parameterization = self.get_action_parameterization(self.critic_decoder.get_out_size())
217217

218218
self.apply(self.initialize_weights)
@@ -284,7 +284,7 @@ def forward_tail(self, core_output, values_only: bool, sample_actions: bool) ->
284284

285285
# second core output corresponds to the critic
286286
critic_decoder_output = self.critic_decoder(core_outputs[1])
287-
values = self.critic_linear(critic_decoder_output).squeeze()
287+
values = self.critic(critic_decoder_output).squeeze()
288288

289289
result = TensorDict(values=values)
290290
if values_only:

sample_factory/model/critic.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import math
2+
from abc import ABC
3+
from typing import List
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch import Tensor
8+
9+
from sample_factory.algo.utils.action_distributions import ContinuousActionDistribution
10+
from sample_factory.algo.utils.torch_utils import calc_num_elements
11+
from sample_factory.model.model_utils import ModelModule, create_mlp, nonlinearity
12+
from sample_factory.utils.typing import Config
13+
14+
15+
class Critic(ModelModule, ABC):
    """Common base class for critic heads; adds nothing on top of ``ModelModule``."""
17+
18+
19+
class MlpCritic(Critic):
    """Critic head: the MLP configured by ``cfg.critic_mlp_layers`` followed by a linear layer with one output."""

    def __init__(self, cfg: Config, critic_input_size: int):
        """
        Build the critic head.

        :param cfg: experiment config; ``critic_mlp_layers`` and the nonlinearity settings are read from it
        :param critic_input_size: number of input features fed to the critic
        """
        super().__init__(cfg)
        self.critic_input_size = critic_input_size
        self.critic_out_size = 1

        hidden_sizes: List[int] = cfg.critic_mlp_layers
        mlp = create_mlp(hidden_sizes, critic_input_size, nonlinearity(cfg))
        # The MLP is scripted only when at least one hidden layer was requested.
        self.mlp = torch.jit.script(mlp) if len(hidden_sizes) > 0 else mlp

        mlp_out_size = calc_num_elements(self.mlp, (critic_input_size,))
        self.critic_linear = nn.Linear(mlp_out_size, self.critic_out_size)

    def forward(self, core_output):
        """Run the input through the MLP and then the final linear layer."""
        hidden = self.mlp(core_output)
        return self.critic_linear(hidden)
35+
36+
37+
class ValueParameterizationContinuousNonAdaptiveStddev(nn.Module):
    """Parameterize the value as a mean from a linear layer plus a single learned stddev parameter."""

    def __init__(self, cfg, core_out_size):
        """
        :param cfg: config object; ``initial_stddev`` is read from it
        :param core_out_size: number of input features of the linear layer producing the value mean
        """
        super().__init__()
        self.cfg = cfg

        # The network itself only predicts the value mean.
        self.distribution_linear = nn.Linear(core_out_size, 1)
        # The stddev is one learned parameter, initialised to log(cfg.initial_stddev).
        stddev_init = torch.full([1], math.log(self.cfg.initial_stddev))
        self.learned_stddev = nn.Parameter(stddev_init, requires_grad=True)

    def forward(self, actor_core_output: Tensor):
        """Return ``(distribution_params, distribution)`` for the given batch of features."""
        value_means = self.distribution_linear(actor_core_output)
        # Tile the single stddev parameter so that every row of the batch gets a copy.
        value_stddevs = self.learned_stddev.repeat(value_means.shape[0], 1)
        value_distribution_params = torch.cat((value_means, value_stddevs), dim=1)
        return value_distribution_params, ContinuousActionDistribution(params=value_distribution_params)
58+
59+
60+
class ParametrizedCritic(Critic):
    """Critic head that models the value as a distribution and returns a sample drawn from it."""

    def __init__(self, cfg: Config, critic_input_size: int):
        """
        :param cfg: experiment config; ``critic_mlp_layers`` and the nonlinearity settings are read from it
        :param critic_input_size: number of input features fed to the critic
        """
        super().__init__(cfg)
        self.critic_input_size = critic_input_size
        self.critic_out_size = 2

        hidden_sizes: List[int] = cfg.critic_mlp_layers
        mlp = create_mlp(hidden_sizes, critic_input_size, nonlinearity(cfg))
        # The MLP is scripted only when at least one hidden layer was requested.
        self.mlp = torch.jit.script(mlp) if len(hidden_sizes) > 0 else mlp

        mlp_out_size = calc_num_elements(self.mlp, (critic_input_size,))
        self.critic_parametrization = ValueParameterizationContinuousNonAdaptiveStddev(cfg, mlp_out_size)

    def forward(self, core_output):
        """Sample values from the distribution produced for ``core_output``.

        The distribution from the latest call is kept in ``self.last_value_distribution``.
        """
        hidden = self.mlp(core_output)
        _, self.last_value_distribution = self.critic_parametrization(hidden)
        # NOTE(review): if ContinuousActionDistribution.sample() follows torch.distributions
        # semantics, the sampled values carry no gradient - confirm before training with this head.
        return self.last_value_distribution.sample()
78+
79+
80+
def default_make_critic_func(cfg: Config, critic_input_size: int) -> Critic:
    """Build the default critic head, an ``MlpCritic``, for the given input size."""
    critic = MlpCritic(cfg, critic_input_size)
    return critic

sample_factory/model/model_factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sample_factory.model.actor_critic import ActorCritic, default_make_actor_critic_func
44
from sample_factory.model.core import ModelCore, default_make_core_func
5+
from sample_factory.model.critic import Critic, default_make_critic_func
56
from sample_factory.model.decoder import Decoder, default_make_decoder_func
67
from sample_factory.model.encoder import Encoder, default_make_encoder_func
78
from sample_factory.utils.typing import ActionSpace, Config, ObsSpace
@@ -11,6 +12,7 @@
1112
MakeEncoderFunc = Callable[[Config, ObsSpace], Encoder]
1213
MakeCoreFunc = Callable[[Config, int], ModelCore]
1314
MakeDecoderFunc = Callable[[Config, int], Decoder]
15+
MakeCriticFunc = Callable[[Config, int], Critic]
1416

1517

1618
class ModelFactory:
@@ -28,6 +30,7 @@ def __init__(self):
2830
self.make_model_encoder_func: MakeEncoderFunc = default_make_encoder_func
2931
self.make_model_core_func: MakeCoreFunc = default_make_core_func
3032
self.make_model_decoder_func: MakeDecoderFunc = default_make_decoder_func
33+
self.make_model_critic_func: MakeCriticFunc = default_make_critic_func
3134

3235
def register_actor_critic_factory(self, make_actor_critic_func: MakeActorCriticFunc):
3336
"""
@@ -59,3 +62,11 @@ def register_decoder_factory(self, make_model_decoder_func: MakeDecoderFunc):
5962
"""
6063
log.debug(f"register_decoder_factory: {make_model_decoder_func}")
6164
self.make_model_decoder_func = make_model_decoder_func
65+
66+
def register_critic_factory(self, make_model_critic_func: MakeCriticFunc):
67+
"""
68+
Override the default critic with a custom model.
69+
The computational graph structure is: observations -> encoder -> core -> decoder -> critic -> values
70+
"""
71+
log.debug(f"register_critic_factory: {make_model_critic_func}")
72+
self.make_model_critic_func = make_model_critic_func

sf_examples/nethack/train_nethack.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from os.path import join
44
from typing import Callable
55

6+
import torch
67
import torch.nn as nn
78

89
from sample_factory.algo.learning.learner import Learner
@@ -12,6 +13,7 @@
1213
from sample_factory.envs.env_utils import register_env
1314
from sample_factory.model.actor_critic import ActorCritic, default_make_actor_critic_func
1415
from sample_factory.model.encoder import Encoder
16+
from sample_factory.model.model_utils import get_rnn_size
1517
from sample_factory.train import run_rl
1618
from sample_factory.utils.typing import ActionSpace, Config, ObsSpace
1719
from sample_factory.utils.utils import log
@@ -71,7 +73,8 @@ def load_pretrained_checkpoint(model, checkpoint_dir: str, checkpoint_kind: str,
7173
del checkpoint_dict["model"]["returns_normalizer.running_var"]
7274
del checkpoint_dict["model"]["returns_normalizer.count"]
7375

74-
model.load_state_dict(checkpoint_dict["model"])
76+
incompatibile = model.load_state_dict(checkpoint_dict["model"], strict=False)
77+
log.debug(incompatibile)
7578

7679

7780
def load_pretrained_checkpoint_from_shared_weights(
@@ -126,17 +129,19 @@ def hook(module, input, output):
126129
else:
127130
register_hooks(child)
128131

129-
register_hooks(model.critic_encoder)
132+
register_hooks(model)
130133

131134
tmp_env = make_env_func_batched(cfg, env_config=None)
132135
obs, info = tmp_env.reset()
133-
model.critic_encoder(obs)
136+
rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32)
137+
model(obs, rnn_states)
134138

135139
if cfg.critic_replace_bn_with_ln:
136140
replace_batchnorm_with_layernorm(model.critic_encoder)
137141
inject_layernorm_before_activation(model.critic_encoder)
138142

139-
model.critic_linear = linear_layernorm(model.critic_linear)
143+
inject_layernorm_before_activation(model.critic)
144+
model.critic.critic_linear = linear_layernorm(model.critic.critic_linear)
140145

141146
for handle in handles:
142147
handle.remove()

0 commit comments

Comments
 (0)