From 52ee6e66d25cf2be28d60e21da5462e8b2e360cd Mon Sep 17 00:00:00 2001 From: Ewa Dobrowolska Date: Wed, 14 May 2025 13:54:04 +0200 Subject: [PATCH 1/3] add simba --- mrunner_exps/cool_mujoco.py | 44 ++++ mrunner_run.py | 3 +- sample_factory/algo/learning/learner.py | 21 +- .../model/action_parameterization.py | 5 +- sample_factory/model/actor_critic.py | 7 +- sample_factory/model/model_utils.py | 8 + sf_examples/mujoco/models/__init__.py | 1 + sf_examples/mujoco/models/simba.py | 232 ++++++++++++++++++ sf_examples/mujoco/train_mujoco.py | 184 +++++++++++++- 9 files changed, 487 insertions(+), 18 deletions(-) create mode 100644 mrunner_exps/cool_mujoco.py create mode 100644 sf_examples/mujoco/models/__init__.py create mode 100644 sf_examples/mujoco/models/simba.py diff --git a/mrunner_exps/cool_mujoco.py b/mrunner_exps/cool_mujoco.py new file mode 100644 index 000000000..0afdd9e14 --- /dev/null +++ b/mrunner_exps/cool_mujoco.py @@ -0,0 +1,44 @@ +from mrunner.helpers.specification_helper import create_experiments_helper + +name = globals()["script"][:-3] + +# params for all exps +config = { + "train_for_env_steps": 1_000_000, + "num_workers": 16, + "num_envs_per_worker": 16, + "worker_num_splits": 2, + "rollout": 32, + "batch_size": 1024, # this equals bs = 128, 128 * 32 = 4096 + "async_rl": True, + "serial_mode": False, + "restart_behavior": "overwrite", + # "device": "cpu", + "with_wandb": True, + "wandb_user": "ideas-ncbr", + "wandb_project": "mujoco plasticity_ed", + "wandb_group": "cool simba", + +} + +# params different between exps +params_grid = [ + { + "seed": list(range(1)), + "env": ["mujoco_hopper"], + "actor_critic_share_weights": [True], + "model": ["simba"], + }, +] + +experiments_list = create_experiments_helper( + experiment_name=name, + project_name="sf2_mujoco", + with_neptune=False, + script="python3 mrunner_run.py", + python_path=".", + tags=[name], + base_config=config, + params_grid=params_grid, + mrunner_ignore=".mrunnerignore", +) diff --git a/mrunner_run.py b/mrunner_run.py index 351181c6b..5030ed610 100644 --- a/mrunner_run.py +++ b/mrunner_run.py @@ -6,7 +6,8 @@ cfg = get_configuration(print_diagnostics=True, with_neptune=False) del cfg["experiment_id"] - run_script = cfg.pop("run_script", "sf_examples.atari.train_atari") + # run_script = cfg.pop("run_script", "sf_examples.atari.train_atari") + run_script = cfg.pop("run_script", "sf_examples.mujoco.train_mujoco") key_pairs = [f"--{key}={value}" for key, value in cfg.items()] cmd = ["python", "-m", run_script] + key_pairs diff --git a/sample_factory/algo/learning/learner.py b/sample_factory/algo/learning/learner.py index 652a501fb..e01775a8f 100644 --- a/sample_factory/algo/learning/learner.py +++ b/sample_factory/algo/learning/learner.py @@ -887,18 +887,18 @@ def _record_summaries(self, train_loop_vars) -> AttrDict: stats.value_loss = var.value_loss stats.exploration_loss = var.exploration_loss - stats.dead_neurons = var.dead_neurons - stats.effective_rank = var.effective_rank - stats.l2_init_loss = var.l2_init_loss + # stats.dead_neurons = var.dead_neurons + # stats.effective_rank = var.effective_rank + # stats.l2_init_loss = var.l2_init_loss if self.cfg.with_rnd: stats.int_rewards = var.int_rewards.mean() stats.curiosity_rewards = var.curiosity_rewards.mean() stats.predictor_loss = var.predictor_loss stats.int_value_loss = var.int_value_loss - if self.train_step % 200 == 0: - stats.per_layer_grad_norms = var.per_layer_grad_norms - stats.per_layer_param_norms = var.per_layer_param_norms + # if self.train_step % 200 == 0: + # stats.per_layer_grad_norms = var.per_layer_grad_norms + # stats.per_layer_param_norms = var.per_layer_param_norms # # Log dead neurons # for layer in var['dead_neurons_dict'].keys(): @@ -995,7 +995,6 @@ def _prepare_batch(self, batch: TensorDict) -> Tuple[TensorDict, int, int]: with torch.no_grad(): # create a shallow copy so we can modify the dictionary # we still reference the same buffers though - print(f"Actions: {batch['actions'].shape}") buff = shallow_recursive_copy(batch) # ignore experience from other agents (i.e. on episode boundary) and from inactive agents @@ -1165,10 +1164,10 @@ def _prepare_batch(self, batch: TensorDict) -> Tuple[TensorDict, int, int]: # likewise, some invalid values of log_prob_actions can cause NaNs or infs buff["log_prob_actions"][invalid_indices] = -1 # -1 seems like a safe value - if self.cfg.with_rnd: - log.debug(f"[RND] rewards={buff['rewards'].mean()}, curiosity_rewards={buff['curiosity_rewards'].mean()}, int_rewards={buff['int_rewards'].mean()}") - else: - log.debug(f"[OLD] rewards={buff['rewards'].mean()}") + # if self.cfg.with_rnd: + # log.debug(f"[RND] rewards={buff['rewards'].mean()}, curiosity_rewards={buff['curiosity_rewards'].mean()}, int_rewards={buff['int_rewards'].mean()}") + # else: + # log.debug(f"[OLD] rewards={buff['rewards'].mean()}") return buff, dataset_size, num_invalids def train(self, batch: TensorDict) -> Optional[Dict]: diff --git a/sample_factory/model/action_parameterization.py b/sample_factory/model/action_parameterization.py index c11cf7c62..0eccfa0a7 100644 --- a/sample_factory/model/action_parameterization.py +++ b/sample_factory/model/action_parameterization.py @@ -8,6 +8,7 @@ get_action_distribution, is_continuous_action_space, ) +from sample_factory.model.model_utils import orthogonal_init class ActionsParameterization(nn.Module): @@ -28,7 +29,7 @@ def __init__(self, cfg, core_out_size, action_space): super().__init__(cfg, action_space) num_action_outputs = calc_num_action_parameters(action_space) - self.distribution_linear = nn.Linear(core_out_size, num_action_outputs) + self.distribution_linear = orthogonal_init(nn.Linear(core_out_size, num_action_outputs), gain=1.0) def forward(self, actor_core_output): """Just forward the FC layer and generate the distribution object.""" @@ -51,7 +52,7 @@ def __init__(self, cfg, core_out_size, action_space): num_action_outputs = calc_num_action_parameters(action_space) # calculate only action means using the policy neural network - self.distribution_linear = nn.Linear(core_out_size, num_action_outputs // 2) + self.distribution_linear = orthogonal_init(nn.Linear(core_out_size, num_action_outputs // 2), gain=1.0) self.tanh_scale: float = cfg.continuous_tanh_scale # stddev is a single learned parameter initial_stddev = torch.empty([num_action_outputs // 2]) diff --git a/sample_factory/model/actor_critic.py b/sample_factory/model/actor_critic.py index e3396c796..a293ae05f 100644 --- a/sample_factory/model/actor_critic.py +++ b/sample_factory/model/actor_critic.py @@ -19,6 +19,7 @@ from sample_factory.utils.normalize import ObservationNormalizer from sample_factory.utils.typing import ActionSpace, Config, ObsSpace from sample_factory.utils.utils import log +# from sample_factory.model.model_utils import orthogonal_init from gymnasium.wrappers.normalize import RunningMeanStd import copy @@ -729,9 +730,9 @@ def default_make_actor_critic_func(cfg: Config, obs_space: ObsSpace, action_spac model_factory = global_model_factory() - if cfg.cleanrl_actor_critic: - return CleanRLActorCritic(model_factory, obs_space, action_space, cfg) - elif cfg.actor_critic_share_weights: + # if cfg.cleanrl_actor_critic: + # return CleanRLActorCritic(model_factory, obs_space, action_space, cfg) + if cfg.actor_critic_share_weights: return ActorCriticSharedWeights(model_factory, obs_space, action_space, cfg) else: return ActorCriticSeparateWeights(model_factory, obs_space, action_space, cfg) diff --git a/sample_factory/model/model_utils.py b/sample_factory/model/model_utils.py index 581f807d2..5a99f40c6 100644 --- a/sample_factory/model/model_utils.py +++ b/sample_factory/model/model_utils.py @@ -8,6 +8,14 @@ from sample_factory.utils.typing import Config +def orthogonal_init(module, gain=1.0): + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.orthogonal_(module.weight, gain=gain) + if module.bias is not None: + module.bias.data.fill_(0.00) + return module + + def get_rnn_size(cfg): if cfg.use_rnn: size = cfg.rnn_size * cfg.rnn_num_layers diff --git a/sf_examples/mujoco/models/__init__.py b/sf_examples/mujoco/models/__init__.py new file mode 100644 index 000000000..eb59b04b3 --- /dev/null +++ b/sf_examples/mujoco/models/__init__.py @@ -0,0 +1 @@ +from sf_examples.mujoco.models.simba import SimBaActorEncoder, SimBaCriticEncoder \ No newline at end of file diff --git a/sf_examples/mujoco/models/simba.py b/sf_examples/mujoco/models/simba.py new file mode 100644 index 000000000..7f676ef06 --- /dev/null +++ b/sf_examples/mujoco/models/simba.py @@ -0,0 +1,232 @@ +import torch +import torch.nn as nn + +from sample_factory.algo.utils.torch_utils import calc_num_elements +from sample_factory.model.encoder import Encoder +from sample_factory.utils.typing import Config, ObsSpace +from sample_factory.algo.utils.running_mean_std import RunningMeanStdInPlace, RunningMeanStd +# from gymnasium.wrappers.normalize import RunningMeanStd +from sample_factory.model.model_utils import orthogonal_init + + +class SimBaEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace, hidden_dim: int, num_blocks: int, use_max_pool: bool, expansion: int = 4): + + super().__init__(cfg) + self.obs_keys = list(sorted(obs_space.keys())) # always the same order + self.encoders = nn.ModuleDict() + + out_size = 0 + + for obs_key in self.obs_keys: + shape = obs_space[obs_key].shape + + if len(shape) == 1: + self.encoders[obs_key] = SimBaEncoderMLP(obs_space[obs_key].shape[0], hidden_dim, num_blocks, expansion) + elif len(shape) > 1: + self.encoders[obs_key] = SimBaCNN(obs_space[obs_key], hidden_dim, num_blocks, use_max_pool, expansion) + else: + raise NotImplementedError(f"Unsupported observation space {obs_space}") + + # self.encoders[obs_key] = encoder_fn(obs_space[obs_key], hidden_dim, num_blocks, expansion) + out_size += self.encoders[obs_key].get_out_size() + + self.encoder_out_size = out_size + + def forward(self, obs_dict): + if len(self.obs_keys) == 1: + key = self.obs_keys[0] + return self.encoders[key](obs_dict[key]) + + encodings = [] + for key in self.obs_keys: + x = self.encoders[key](obs_dict[key]) + encodings.append(x) + + return torch.cat(encodings, 1) + + def get_out_size(self) -> int: + return self.encoder_out_size + + +class SimBaConvBlock(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + + # GroupNorm with num_groups=1 is equivalent to LayerNorm + self.layer_norm = nn.GroupNorm(1, in_channels) + + self.conv_block = nn.Sequential( + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, bias=False), + nn.ELU(inplace=True), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1), + ) + + # Add projection layer if channels change + self.projection = None + if in_channels != out_channels: + self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + identity = x + out = self.layer_norm(x) + out = self.conv_block(out) + + if self.projection is not None: + identity = self.projection(identity) + + return identity + out + + +class SimBaCNN(nn.Module): + def __init__( + self, + obs_space, + hidden_dim=64, + num_blocks=2, + use_max_pool=False, + expansion=2, + ): + super().__init__() + self.hidden_dim = hidden_dim + self.use_max_pool = use_max_pool + in_channels = obs.space["screen_image"].shape[0] + + assert in_channels & (in_channels - 1) == 0, "in_channels must be power of 2" + assert hidden_dim & (hidden_dim - 1) == 0, "hidden_dim must be power of 2" + assert hidden_dim >= in_channels, "hidden_dim must be >= in_channels" + assert not use_max_pool or (use_max_pool and num_blocks <= 4) + + # Calculate number of doublings needed + current_channels = in_channels + self.blocks = [] + + # Initial convolution to project to hidden dimension + self.initial_conv = orthogonal_init( + nn.Conv2d(in_channels, current_channels * 2, kernel_size=3, padding=0, bias=False), + gain=1.0, + ) + current_channels *= 2 + + # SimBa residual blocks + self.blocks = [] + for i in range(num_blocks): + next_channels = min(current_channels * 2, hidden_dim) + self.blocks.append(SimBaConvBlock(current_channels, next_channels * expansion, next_channels)) + if self.use_max_pool: + self.blocks.append(nn.MaxPool2d(kernel_size=2, stride=2)) + current_channels = next_channels + self.blocks = nn.ModuleList(self.blocks) + + # Post-layer normalization + # GroupNorm with num_groups=1 is equivalent to LayerNorm + self.post_norm = nn.GroupNorm(1, current_channels) + + # Global average pooling + self.pooling = nn.AdaptiveAvgPool2d(1) + + def forward(self, x): + # Initial projection + x = self.initial_conv(x) + + # Residual blocks + for block in self.blocks: + x = block(x) + + # Post normalization + x = self.post_norm(x) + + # Global pooling + x = self.pooling(x) + x = x.view(x.size(0), -1) + + return x + + def get_out_size(self): + return self.hidden_dim + + +class SimBaMLPBlock(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.ln = nn.GroupNorm(1, dim) + self.fc1 = orthogonal_init(nn.Linear(dim, hidden_dim), gain=1.0) + self.act = nn.ELU(inplace=True) + self.fc2 = orthogonal_init(nn.Linear(hidden_dim, dim), gain=1.0) + + def forward(self, x): + identity = x + out = self.ln(x) + out = self.act(self.fc1(out)) + out = self.fc2(out) + return out + identity + + +class SimBaEncoderMLP(nn.Module): + def __init__(self, obs_dim, hidden_dim: int, num_blocks: int, expansion: int = 4): + super().__init__() + self.obs_dim = obs_dim + self.hidden_dim = hidden_dim + self.input_projection = orthogonal_init(nn.Linear(obs_dim, self.hidden_dim)) + # self.norm = RunningMeanStdInPlace((self.hidden_dim,)) + # self.norm = RunningMeanStd((self.hidden_dim,)) + + self.blocks = nn.ModuleList( + [SimBaMLPBlock(self.hidden_dim, expansion*self.hidden_dim) for _ in range(num_blocks)] + ) + # self.output_projection = orthogonal_init(nn.Linear(curr_dim/2, obs_dim)) + self.post_ln = nn.GroupNorm(1, self.hidden_dim) + + def forward(self, x): + out = x + out = self.input_projection(out) + # μ, σ, clip = self.norm.forward(out) + out = out.sub(μ).mul(1 / σ).clamp(-clip, clip) + + for block in self.blocks: + out = block(out) + out = self.post_ln(out) + return out + + def get_out_size(self): + return self.hidden_dim + + +class SimBaActorEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + + self.model = SimBaEncoder( + cfg=cfg, + obs_space=obs_space, + hidden_dim=cfg.actor_hidden_dim, + num_blocks=cfg.actor_depth, + use_max_pool=cfg.actor_use_max_pool, + expansion=cfg.actor_expansion + ) + + def forward(self, x): + return self.model(x) + + def get_out_size(self): + return self.model.get_out_size() + + +class SimBaCriticEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + + self.model = SimBaEncoder( + cfg=cfg, + obs_space=obs_space, + hidden_dim=cfg.critic_hidden_dim, + num_blocks=cfg.critic_depth, + use_max_pool=cfg.critic_use_max_pool, + expansion=cfg.critic_expansion, + ) + + def forward(self, x): + return self.model(x) + + def get_out_size(self): + return self.model.get_out_size() \ No newline at end of file diff --git a/sf_examples/mujoco/train_mujoco.py b/sf_examples/mujoco/train_mujoco.py index 2c2c586a9..9c1f33d26 100644 --- a/sf_examples/mujoco/train_mujoco.py +++ b/sf_examples/mujoco/train_mujoco.py @@ -1,19 +1,201 @@ import sys +import ast +from typing import Callable -from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args + +from sample_factory.cfg.arguments import load_from_path, parse_full_cfg, parse_sf_args from sample_factory.envs.env_utils import register_env from sample_factory.train import run_rl from sf_examples.mujoco.mujoco_params import add_mujoco_env_args, mujoco_override_defaults from sf_examples.mujoco.mujoco_utils import MUJOCO_ENVS, make_mujoco_env +from sample_factory.algo.utils.context import global_model_factory, sf_global_context +from sample_factory.model.encoder import Encoder, MultiInputEncoder +from sample_factory.utils.typing import ActionSpace, Config, ObsSpace +from sample_factory.utils.utils import log, str2bool +from sample_factory.model.actor_critic import ( + ActorCritic, + ActorCriticSeparateWeights, + ActorCriticSharedWeights, + # obs_space_without_action_mask, +) +from sf_examples.mujoco.models import ( + SimBaActorEncoder, + SimBaCriticEncoder, +) + +def add_extra_params_general(parser): + """ + Specify any additional command line arguments for NetHack. + """ + # TODO: add help + p = parser + p.add_argument("--exp_tags", type=str, default="local") + p.add_argument("--exp_point", type=str, default="point-A") + p.add_argument("--group", type=str, default="group2") + p.add_argument("--use_pretrained_checkpoint", type=str2bool, default=False) + p.add_argument("--model", type=str, default="default") + p.add_argument("--model_path", type=str, default=None) + p.add_argument("--supervised_loss_coeff", type=float, default=0.0) + p.add_argument("--kickstarting_loss_coeff", type=float, default=0.0) + p.add_argument("--distillation_loss_coeff", type=float, default=0.0) + p.add_argument("--supervised_loss_decay", type=float, default=1.0) + p.add_argument("--kickstarting_loss_decay", type=float, default=1.0) + p.add_argument("--distillation_loss_decay", type=float, default=1.0) + p.add_argument("--min_supervised_loss_coeff", type=float, default=0.0) + p.add_argument("--min_kickstarting_loss_coeff", type=float, default=0.0) + p.add_argument("--min_distillation_loss_coeff", type=float, default=0.0) + p.add_argument("--substitute_regularization_with_exploration", type=str2bool, default=False) + p.add_argument("--exploration_coeff_on_supervised_loss_coeff", type=float, default=0.0) + p.add_argument("--exploration_coeff_on_kickstarting_loss_coeff", type=float, default=0.0) + p.add_argument("--exploration_coeff_on_distillation_loss_coeff", type=float, default=0.0) + p.add_argument("--teacher_path", type=str, default=None) + p.add_argument("--run_teacher_hs", type=str2bool, default=False) + p.add_argument("--add_stats_to_info", type=str2bool, default=True) + p.add_argument("--capture_video", type=str2bool, default=False) + p.add_argument("--capture_video_ith", type=int, default=100) + p.add_argument("--freeze", type=ast.literal_eval, default={}) + p.add_argument("--unfreeze", type=ast.literal_eval, default={}) + p.add_argument("--freeze_batch_norm", type=str2bool, default=False) + p.add_argument("--skip_train", type=int, default=-1) + p.add_argument("--target_batch_size", type=int, default=128) + p.add_argument("--optim_step_every_ith", type=int, default=1) + p.add_argument("--actor_depth", type=int, default=1) + p.add_argument("--actor_hidden_dim", type=int, default=64) + p.add_argument("--actor_expansion", type=int, default=4) + p.add_argument("--actor_use_max_pool", type=str2bool, default=False) + p.add_argument("--critic_depth", type=int, default=1) + p.add_argument("--critic_hidden_dim", type=int, default=64) + p.add_argument("--critic_expansion", type=int, default=4) + p.add_argument("--critic_use_max_pool", type=str2bool, default=False) + +class ActorCriticDifferentEncoders(ActorCriticSeparateWeights): + def __init__(self, model_factory, obs_space, action_space, cfg): + super().__init__(model_factory, obs_space, action_space, cfg) + + self.actor_encoder = SimBaActorEncoder(cfg, obs_space) + self.actor_core = model_factory.make_model_core_func(cfg, self.actor_encoder.get_out_size()) + + self.critic_encoder = SimBaCriticEncoder(cfg, obs_space) + self.critic_core = model_factory.make_model_core_func(cfg, self.critic_encoder.get_out_size()) + + self.encoders = [self.actor_encoder, self.critic_encoder] + self.cores = [self.actor_core, self.critic_core] + + self.core_func = self._core_rnn if self.cfg.use_rnn else self._core_empty + + self.actor_decoder = model_factory.make_model_decoder_func(cfg, self.actor_core.get_out_size()) + self.critic_decoder = model_factory.make_model_decoder_func(cfg, self.critic_core.get_out_size()) + self.decoders = [self.actor_decoder, self.critic_decoder] + + self.critic_linear = orthogonal_init(nn.Linear(self.critic_decoder.get_out_size(), 1), gain=1.0) + self.action_parameterization = self.get_action_parameterization(self.actor_decoder.get_out_size()) + + self.encoder_outputs_sizes = [encoder.get_out_size() for encoder in self.encoders] + self.rnn_hidden_sizes = [core.core.hidden_size * 2 for core in self.cores] + self.core_outputs_sizes = [decoder.get_out_size() for decoder in self.decoders] + + +def make_mujoco_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: ActionSpace) -> ActorCritic: + from sample_factory.algo.utils.context import global_model_factory + + model_factory = global_model_factory() + # obs_space = obs_space_without_action_mask(obs_space) + + if cfg.model == "simba": + if cfg.actor_critic_share_weights: + return ActorCriticSharedWeights(model_factory, obs_space, action_space, cfg) + else: + return ActorCriticDifferentEncoders(model_factory, obs_space, action_space, cfg) + elif cfg.model == "default": + if cfg.actor_critic_share_weights: + return ActorCriticSharedWeights(model_factory, obs_space, action_space, cfg) + else: + return ActorCriticSeparateWeights(model_factory, obs_space, action_space, cfg) + else: + raise NotImplementedError + +def load_pretrained_checkpoint(model, checkpoint_dir: str, checkpoint_kind: str, normalize_returns: bool = True): + name_prefix = dict(latest="checkpoint", best="best")[checkpoint_kind] + checkpoints = Learner.get_checkpoints(join(checkpoint_dir, "checkpoint_p0"), f"{name_prefix}_*") + checkpoint_dict = Learner.load_checkpoint(checkpoints, "cpu") + + student_params = dict( + filter( + lambda x: x[0].startswith("student"), + checkpoint_dict["model"].items(), + ) + ) + + if len(student_params) > 0: + # means that the pretrained checkpoint was a KickStarter, we only want to load the student + student_params = dict(map(lambda x: (x[0].removeprefix("student."), x[1]), student_params.items())) + checkpoint_dict["model"] = student_params + + if not normalize_returns: + del checkpoint_dict["model"]["returns_normalizer.running_mean"] + del checkpoint_dict["model"]["returns_normalizer.running_var"] + del checkpoint_dict["model"]["returns_normalizer.count"] + else: + checkpoint_dict["model"]["returns_normalizer.running_mean"][:] = 0 + checkpoint_dict["model"]["returns_normalizer.running_var"][:] = 1 + checkpoint_dict["model"]["returns_normalizer.count"][:] = 1 + + # TODO: handle loading linear critic + model.load_state_dict(checkpoint_dict["model"], strict=False) + + +def load_pretrained_checkpoint_from_shared_weights( + model: ActorCritic, + cfg: Config, + checkpoint_dir: str, + checkpoint_kind: str, + create_model: Callable, + obs_space: ObsSpace, + action_space: ActionSpace, +): + # since our pretrained checkpoints have shared weights we load them in that format + # then create temporary model with separate actor and critic with modules from pretrained model + # we finally use load_state_dict to ensure that the shapes match + cfg.actor_critic_share_weights = True + model_shared = create_model(cfg, obs_space, action_space) + load_pretrained_checkpoint(model_shared, checkpoint_dir, checkpoint_kind, normalize_returns=cfg.normalize_returns) + cfg.actor_critic_share_weights = False + tmp_model: ActorCritic = create_model(cfg, obs_space, action_space) + + tmp_model.obs_normalizer = copy.deepcopy(model_shared.obs_normalizer) + tmp_model.returns_normalizer = copy.deepcopy(model_shared.returns_normalizer) + tmp_model.actor_encoder = copy.deepcopy(model_shared.encoder) + tmp_model.actor_core = copy.deepcopy(model_shared.core) + tmp_model.actor_decoder = copy.deepcopy(model_shared.decoder) + tmp_model.action_parameterization = copy.deepcopy(model_shared.action_parameterization) + + if cfg.init_critic_from_actor: + tmp_model.critic_encoder = copy.deepcopy(model_shared.encoder) + tmp_model.critic_core = copy.deepcopy(model_shared.core) + tmp_model.critic_decoder = copy.deepcopy(model_shared.decoder) + # tmp_model.critic_linear = copy.deepcopy(model_shared.critic_linear) + + model.load_state_dict(tmp_model.state_dict(), strict=False) + + + +def make_mujoco_encoder(cfg: Config, obs_space: ObsSpace) -> Encoder: + if cfg.model == "default": + return MultiInputEncoder(cfg, obs_space) + elif cfg.model == "simba": + return SimBaActorEncoder(cfg, obs_space) def register_mujoco_components(): for env in MUJOCO_ENVS: register_env(env.name, make_mujoco_env) + global_model_factory().register_encoder_factory(make_mujoco_encoder) + global_model_factory().register_actor_critic_factory(make_mujoco_actor_critic) def parse_mujoco_cfg(argv=None, evaluation=False): parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation) + add_extra_params_general(parser) add_mujoco_env_args(partial_cfg.env, parser) mujoco_override_defaults(partial_cfg.env, parser) final_cfg = parse_full_cfg(parser, argv) From 96a5529552d11ac40254912d8411285f0e081552 Mon Sep 17 00:00:00 2001 From: Ewa Dobrowolska Date: Mon, 19 May 2025 20:24:45 +0200 Subject: [PATCH 2/3] update --- mrunner_exps/atari_reproduce_paper.py | 8 +- mrunner_exps/cool_mujoco.py | 4 +- sample_factory/algo/utils/running_mean_std.py | 83 +++++++++++++++++++ sample_factory/cfg/cfg.py | 20 +++++ sf_examples/atari/train_atari.py | 2 +- sf_examples/mujoco/models/simba.py | 2 +- 6 files changed, 111 insertions(+), 8 deletions(-) diff --git a/mrunner_exps/atari_reproduce_paper.py b/mrunner_exps/atari_reproduce_paper.py index 98dbe89e7..f77d3d277 100644 --- a/mrunner_exps/atari_reproduce_paper.py +++ b/mrunner_exps/atari_reproduce_paper.py @@ -5,7 +5,7 @@ # params for all exps config = { "exp_tags": [name], - "train_for_env_steps": 200_000_000, + "train_for_env_steps": 1_000_000, "num_workers": 4, "num_envs_per_worker": 8, "num_batches_per_epoch": 16, @@ -17,11 +17,11 @@ "wandb_project": "atari", "wandb_group": "plasticity, reproduce paper -- 3", "wandb_tags": [name], - "with_wandb": True, + "with_wandb": False, } # params different between exps -atari_games = ["breakout", "montezuma", "phoenix", "namethisgame"] +atari_games = ["breakout"] params_grid = [] @@ -56,7 +56,7 @@ "optimizer": ["adam"], "num_epochs": [8], "normalize_returns": [True], - "repeat_action_probability": [0.0, 0.25], + "repeat_action_probability": [0.0], # paper's params: plasticity "delta": [0.99], diff --git a/mrunner_exps/cool_mujoco.py b/mrunner_exps/cool_mujoco.py index 0afdd9e14..570a47223 100644 --- a/mrunner_exps/cool_mujoco.py +++ b/mrunner_exps/cool_mujoco.py @@ -13,8 +13,8 @@ "async_rl": True, "serial_mode": False, "restart_behavior": "overwrite", - # "device": "cpu", - "with_wandb": True, + "device": "cpu", + # "with_wandb": True, "wandb_user": "ideas-ncbr", "wandb_project": "mujoco plasticity_ed", "wandb_group": "cool simba", diff --git a/sample_factory/algo/utils/running_mean_std.py b/sample_factory/algo/utils/running_mean_std.py index 9d0848890..84d13e621 100644 --- a/sample_factory/algo/utils/running_mean_std.py +++ b/sample_factory/algo/utils/running_mean_std.py @@ -149,3 +149,86 @@ def running_mean_std_summaries(running_mean_std_module: Union[nn.Module, ScriptM res[name.replace("_var", "_std")] = torch.sqrt(buf.float() + _NORM_EPS).mean() return res + + +class RunningMeanStd(nn.Module): + def __init__(self, input_shape, epsilon=_NORM_EPS, clip=_DEFAULT_CLIP, per_channel=False, norm_only=False): + super().__init__() + log.debug("RunningMeanStd input shape: %r", input_shape) + self.input_shape: Final = input_shape + self.eps: Final[float] = epsilon + self.clip: Final[float] = clip + + self.norm_only: Final[bool] = norm_only + self.per_channel: Final[bool] = per_channel + + if per_channel: + if len(self.input_shape) == 3: + self.axis = [0, 2, 3] + if len(self.input_shape) == 2: + self.axis = [0, 2] + if len(self.input_shape) == 1: + self.axis = [0] + shape = self.input_shape[0] + else: + self.axis = [0] + shape = input_shape + + self.register_buffer("running_mean", torch.zeros(shape, dtype=torch.float64)) + self.register_buffer("running_var", torch.ones(shape, dtype=torch.float64)) + self.register_buffer("count", torch.ones([1], dtype=torch.float64)) + + @staticmethod + @torch.jit.script + def _update_mean_var_count_from_moments( + mean: Tensor, var: Tensor, count: Tensor, batch_mean: Tensor, batch_var: Tensor, batch_count: int + ): + delta = batch_mean - mean + tot_count = count + batch_count + + new_mean = mean + delta * batch_count / tot_count + m_a = var * count + m_b = batch_var * batch_count + M2 = m_a + m_b + (delta**2) * count * batch_count / tot_count + new_var = M2 / tot_count + return new_mean, new_var, tot_count + + def forward(self, x: Tensor, denormalize: bool = False) -> None: + """This function modifies the input tensor and returns the normalized tensor.""" + x_copy = x.clone() + if self.training and not denormalize: + # check if the shape exactly matches or it's a scalar for which we use shape (1, ) + assert x_copy.shape[1:] == self.input_shape or ( + x_copy.shape[1:] == () and self.input_shape == (1,) + ), f"RMS expected input shape {self.input_shape}, got {x_copy.shape[1:]}" + + batch_count = x_copy.size()[0] + μ = x_copy.mean(self.axis) # along channel axis + σ2 = x_copy.var(self.axis) + self.running_mean[:], self.running_var[:], self.count[:] = self._update_mean_var_count_from_moments( + self.running_mean, self.running_var, self.count, μ, σ2, batch_count + ) + + # change shape + if self.per_channel: + if len(self.input_shape) == 3: + current_mean = self.running_mean.view([1, self.input_shape[0], 1, 1]).expand_as(x) + current_var = self.running_var.view([1, self.input_shape[0], 1, 1]).expand_as(x) + elif len(self.input_shape) == 2: + current_mean = self.running_mean.view([1, self.input_shape[0], 1]).expand_as(x) + current_var = self.running_var.view([1, self.input_shape[0], 1]).expand_as(x) + elif len(self.input_shape) == 1: + current_mean = self.running_mean.view([1, self.input_shape[0]]).expand_as(x) + current_var = self.running_var.view([1, self.input_shape[0]]).expand_as(x) + else: + raise RuntimeError(f"RunningMeanStd input shape {self.input_shape} not supported") + else: + current_mean = self.running_mean + current_var = self.running_var + + μ = current_mean.float() + σ2 = current_var.float() + σ = torch.sqrt(σ2 + self.eps) + clip = self.clip + + return μ, σ, clip \ No newline at end of file diff --git a/sample_factory/cfg/cfg.py b/sample_factory/cfg/cfg.py index 4e3ebb1f2..9e636f5f6 100644 --- a/sample_factory/cfg/cfg.py +++ b/sample_factory/cfg/cfg.py @@ -283,6 +283,18 @@ def add_rl_args(p: ArgumentParser): help="c_hat clipping parameter of the V-trace algorithm. Low values for c_hat can reduce variance of the advantage estimates (similar to GAE lambda < 1)", ) + # old other params + p.add_argument("--critic_layer_norm", type=str2bool, default=False) + p.add_argument("--encoder_conv_scale", type=int, default=1) + p.add_argument( + "--critic_learning_rate", + type=float, + default=None, + help="this parameter doesn't work with with lr_scheduler, it will be overwritten.", + ) + p.add_argument("--remove_critic", type=str2bool, default=False) + + # plasticity p.add_argument("--tau", type=float, default=0.1, help="Threshold for dead/dormant neurons") # Default: as in the paper p.add_argument("--delta", type=float, default=0.99, help="Threshold for effective rank") # Default: as in the paper @@ -295,6 +307,14 @@ def add_rl_args(p: ArgumentParser): p.add_argument("--modules_to_perturb", default=None, type=ast.literal_eval, help="List of modules that Shrink&Perturb will be applied to, default: all of them (the entire actor-critic)") # Default: as in the paper p.add_argument("--freeze_predictor", default=0, type=int, help="Number of train steps when predictor is frozen after S+P") + # RND + p.add_argument("--with_rnd", default=False, type=str2bool, help="Enables Random Network Distillation") + p.add_argument("--int_gamma", default=0.99, type=float, help="Gamma used for intrinsic rewards in RND") + p.add_argument("--int_coeff", default=1, type=float, help="coefficient/weight of intrinsic advatages in RND") + p.add_argument("--ext_coeff", default=2, type=float, help="coefficient/weight of extrinsic advatages in RND") + p.add_argument("--keep_prob", default=0.25, type=float, help="proportion of experience used for training predictor network in RND") + p.add_argument("--cleanrl_actor_critic", default=False, type=str2bool, help="Use the same ActorCritic architecture as in CleanRL's RND") + # optimization p.add_argument( "--optimizer", diff --git a/sf_examples/atari/train_atari.py b/sf_examples/atari/train_atari.py index d8e9e6178..5aeb355b6 100644 --- a/sf_examples/atari/train_atari.py +++ b/sf_examples/atari/train_atari.py @@ -44,7 +44,7 @@ def extra_summaries(self, runner: Runner, policy_id: PolicyID, writer: SummaryWr def register_msg_handlers(cfg: Config, runner: Runner): if cfg.env == "atari_montezuma": - log.debug(f"Using motezuma handler") + log.debug(f"Using montezuma handler") # extra functions to calculate room-level heatmaps etc. runner.register_episodic_stats_handler(montezuma_extra_episodic_stats_processing) runner.register_observer(MontezumaExtraSummariesObserver()) diff --git a/sf_examples/mujoco/models/simba.py b/sf_examples/mujoco/models/simba.py index 7f676ef06..71ae1fb13 100644 --- a/sf_examples/mujoco/models/simba.py +++ b/sf_examples/mujoco/models/simba.py @@ -181,7 +181,7 @@ def forward(self, x): out = x out = self.input_projection(out) # μ, σ, clip = self.norm.forward(out) - out = out.sub(μ).mul(1 / σ).clamp(-clip, clip) + # out = out.sub(μ).mul(1 / σ).clamp(-clip, clip) for block in self.blocks: out = block(out) From 51140489a9722eee34c8216745762f716aecca77 Mon Sep 17 00:00:00 2001 From: Ewa Dobrowolska Date: Tue, 27 May 2025 22:51:35 +0200 Subject: [PATCH 3/3] add bro, first version --- exp_runner.py | 179 ++++++++++++++++++++++++++ mrunner_exps/cool_mujoco.py | 24 +++- mrunner_run.py | 5 +- sf_examples/mujoco/models/__init__.py | 3 +- sf_examples/mujoco/models/bro.py | 156 ++++++++++++++++++++++ sf_examples/mujoco/train_mujoco.py | 15 ++- 6 files changed, 372 insertions(+), 10 deletions(-) create mode 100644 exp_runner.py create mode 100644 sf_examples/mujoco/models/bro.py diff --git a/exp_runner.py b/exp_runner.py new file mode 100644 index 000000000..09b31f880 --- /dev/null +++ b/exp_runner.py @@ -0,0 +1,179 @@ +import os +import sys +import subprocess +import importlib.util +from pathlib import Path +import argparse +from typing import List, Tuple, Dict +import json + + +def load_experiment_config(config_file_path: str) -> List[Dict[str, str]]: + experiments = [] + + with open(config_file_path, 'r') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + try: + exp = json.loads(line) + experiments.append(exp) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON on line {line_num}: {e}") + return experiments + + +def create_sbatch_script( + command: str, + script_path: str, + log_path: str, + job_name: str, + venv_path: str, + partition: str, + account: str, + time: str = "24:00:00", + nodes: int = 1, + ntasks: int = 1, + cpus: int = 8, + mem: str = "32G", + gpu: int = 1, +) -> None: + + sbatch_content = f"""#!/bin/bash -l +#SBATCH --job-name={job_name} +#SBATCH --output={log_path} +#SBATCH --error={log_path} +#SBATCH --time={time} +#SBATCH --account={account} +#SBATCH --partition={partition} +#SBATCH --ntasks={ntasks} +#SBATCH --nodes={nodes} +#SBATCH --cpus-per-task={cpus} +#SBATCH --mem={mem} +#SBATCH --gres=gpu:{gpu} + +ml ML-bundle/24.06a + +export WANDB_API_KEY=... +cd "$(dirname "$0")" + +""" + + # Add virtual environment activation if specified + if venv_path: + sbatch_content += f"source {venv_path}/bin/activate" + + # Add the actual command + sbatch_content += f""" + +{command} + +echo "" +echo "Job finished at: $(date)" +""" + + with open(script_path, 'w') as f: + f.write(sbatch_content) + + # Make the script executable + os.chmod(script_path, 0o755) + + +def submit_job(sbatch_script_path: str, dry_run: bool = False) -> str: + if dry_run: + return + + try: + subprocess.run( + ["sbatch", sbatch_script_path], + capture_output=True, + text=True, + check=True + ) + return + except subprocess.CalledProcessError as e: + print(f"Error submitting job {sbatch_script_path}: {e}") + print(f"STDOUT: {e.stdout}") + print(f"STDERR: {e.stderr}") + return + +def main(): + parser = argparse.ArgumentParser(description="Run experiments with SLURM") + parser.add_argument("config_file", help="Path to the experiment configs") + parser.add_argument("base_dir", default="...", help="Storage_dir") + parser.add_argument("--venv", default=".atari_venv", help="Path to virtual environment to activate") + parser.add_argument("--account", default="...", help="Account") + parser.add_argument("--partition", default="...", help="Partition") + parser.add_argument("--time", default="2880", help="Job time limit") + parser.add_argument("--nodes", type=int, default=1, help="Number of nodes (default: 1)") + parser.add_argument("--ntasks", type=int, default=1, help="Tasks per node (default: 1)") + parser.add_argument("--cpus", type=int, default=8, help="CPUs per task (default: 8)") + parser.add_argument("--mem", default="32G", help="Memory per node (default: 32G)") + parser.add_argument("--gpu", default="1", help="Number of gpus (default: 1)") + parser.add_argument("--dry-run", action="store_true", help="Don't actually submit jobs, just show what would be done") + + args = parser.parse_args() + + # Validate inputs + if not os.path.exists(args.config_file): + print(f"Error: Config file {args.config_file} does not exist") + sys.exit(1) + + # Load experiment configuration + print(f"Loading experiment configuration from {args.config_file}...") + try: + experiments = load_experiment_config(args.config_file) + except Exception as e: + print(f"Error loading config: {e}") + sys.exit(1) + + project_name = experiments[0].get('project_name') + unique_name = experiments[0].get('unique_name') + + print(f"Running {len(experiments)} experiments") + + # Create directory structure + base_path = Path(args.base_dir) + project_path = base_path / project_name + unique_path = project_path / unique_name + + unique_path.mkdir(parents=True, exist_ok=True) + + # Generate and submit jobs + job_ids = [] + + for i, exp in enumerate(experiments): + exp_name = f"{exp['name']}_{i}" + exp_dir = unique_path / exp_name + exp_dir.mkdir(exist_ok=True) + + # Paths for script and log files + script_path = exp_dir / f"launch.sbatch" + log_path = exp_dir / f"log.out" + + command = exp.get('command') + + # Create SLURM batch script + create_sbatch_script( + command=command, + script_path=str(script_path), + log_path=str(log_path), + job_name=exp_name, + venv_path=args.venv, + partition=args.partition, + time=args.time, + nodes=args.nodes, + ntasks=args.ntasks, + cpus=args.cpus, + mem=args.mem, + gpu=args.gpu, + account=args.account + ) + + # Submit job + submit_job(str(script_path), dry_run=args.dry_run) + + print(f"Experiment files created in: {unique_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/mrunner_exps/cool_mujoco.py b/mrunner_exps/cool_mujoco.py index 570a47223..e66daca8e 100644 --- a/mrunner_exps/cool_mujoco.py +++ b/mrunner_exps/cool_mujoco.py @@ -1,4 +1,5 @@ from mrunner.helpers.specification_helper import create_experiments_helper +import json name = globals()["script"][:-3] @@ -24,10 +25,10 @@ # params different between exps params_grid = [ { - "seed": list(range(1)), + "seed": list(range(3)), "env": ["mujoco_hopper"], - "actor_critic_share_weights": [True], - "model": ["simba"], + "actor_critic_share_weights": [True, False], + "model": ["bro"], }, ] @@ -42,3 +43,20 @@ params_grid=params_grid, mrunner_ignore=".mrunnerignore", ) +from mrunner.helpers.client_helper import get_configuration + +exps = [] + +for i, exp in enumerate(experiments_list): + curr_config = {"project_name": exp.project, "unique_name": exp.unique_name, "name": exp.name} + params = exp.parameters + run_script = params.pop("run_script", "sf_examples.mujoco.train_mujoco") + key_pairs = [f"--{key}={value}" for key, value in params.items()] + cmd = ["python", "-m", run_script] + key_pairs + curr_config["command"] = " ".join(cmd) + exps.append(curr_config) + + +with open("config.jsonl", "w") as f: + for item in exps: + f.write(json.dumps(item) + "\n") \ No newline at end of file diff --git a/mrunner_run.py b/mrunner_run.py index 5030ed610..dd31bc123 100644 --- a/mrunner_run.py +++ b/mrunner_run.py @@ -4,11 +4,12 @@ if __name__ == "__main__": cfg = get_configuration(print_diagnostics=True, with_neptune=False) - + # print("Configuration:", cfg) del cfg["experiment_id"] # run_script = cfg.pop("run_script", "sf_examples.atari.train_atari") run_script = cfg.pop("run_script", "sf_examples.mujoco.train_mujoco") key_pairs = [f"--{key}={value}" for key, value in cfg.items()] cmd = ["python", "-m", run_script] + key_pairs - subprocess.run(cmd) + # subprocess.run(cmd) + # print("Running command:", " ".join(cmd)) diff --git a/sf_examples/mujoco/models/__init__.py b/sf_examples/mujoco/models/__init__.py index eb59b04b3..3e67da4d0 100644 --- a/sf_examples/mujoco/models/__init__.py +++ b/sf_examples/mujoco/models/__init__.py @@ -1 +1,2 @@ -from sf_examples.mujoco.models.simba import SimBaActorEncoder, SimBaCriticEncoder \ No newline at end of file +from sf_examples.mujoco.models.simba import SimBaActorEncoder, SimBaCriticEncoder +from sf_examples.mujoco.models.bro import BROActorEncoder, BROCriticEncoder \ No newline at end of file diff --git a/sf_examples/mujoco/models/bro.py b/sf_examples/mujoco/models/bro.py new file mode 100644 index 000000000..df515f5da --- /dev/null +++ b/sf_examples/mujoco/models/bro.py @@ -0,0 +1,156 @@ +import torch +import torch.nn as nn + +from sample_factory.algo.utils.torch_utils import calc_num_elements +from sample_factory.model.encoder import Encoder +from sample_factory.utils.typing import Config, ObsSpace +from sample_factory.algo.utils.running_mean_std import RunningMeanStdInPlace, RunningMeanStd +# from gymnasium.wrappers.normalize import RunningMeanStd +from sample_factory.model.model_utils import orthogonal_init + + +class BROEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace, hidden_dim: int, num_blocks: int): + + super().__init__(cfg) + self.obs_keys = list(sorted(obs_space.keys())) # always the same order + self.encoders = nn.ModuleDict() + + out_size = 0 + + for obs_key in self.obs_keys: + shape = obs_space[obs_key].shape + + if len(shape) == 1: + self.encoders[obs_key] = BROEncoderMLP(obs_space[obs_key].shape[0], hidden_dim, num_blocks) + elif len(shape) > 1: + raise NotImplementedError(f"Conv encoder not implemented yet") + # self.encoders[obs_key] = BROCNN(obs_space[obs_key], ...) + else: + raise NotImplementedError(f"Unsupported observation space {obs_space}") + + # self.encoders[obs_key] = encoder_fn(obs_space[obs_key], hidden_dim, num_blocks, expansion) + out_size += self.encoders[obs_key].get_out_size() + + self.encoder_out_size = out_size + + def forward(self, obs_dict): + if len(self.obs_keys) == 1: + key = self.obs_keys[0] + return self.encoders[key](obs_dict[key]) + + encodings = [] + for key in self.obs_keys: + x = self.encoders[key](obs_dict[key]) + encodings.append(x) + + return torch.cat(encodings, 1) + + def get_out_size(self) -> int: + return self.encoder_out_size + + +# class BROConvBlock(nn.Module): +# def __init__(self, ...): +# super().__init__() +# ... + +# def forward(self, x): + + +# class BROCNN(nn.Module): +# def __init__( +# self, +# obs_space, +# ..., +# ): +# super().__init__() +# ... + +# def forward(self, x): +# ... + +# def get_out_size(self): +# ... + +class BROMLPBlock(nn.Module): + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.ln1 = nn.GroupNorm(1, dim) + self.fc1 = orthogonal_init(nn.Linear(dim, hidden_dim), gain=1.0) + self.ln2 = nn.GroupNorm(1, hidden_dim) + self.fc2 = orthogonal_init(nn.Linear(hidden_dim, dim), gain=1.0) + self.act = nn.ELU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + out = self.ln1(x) + out = self.fc1(out) + out = self.act(out) + out = self.ln2(out) + out = self.fc2(out) + return out + identity + +class BROEncoderMLP(nn.Module): + def __init__(self, + obs_dim: int, + hidden_dim: int = 512, + num_blocks: int = 2): + super().__init__() + self.hidden_dim = hidden_dim + + self.stem = nn.Sequential( + orthogonal_init(nn.Linear(obs_dim, hidden_dim), gain=1.0), + nn.GroupNorm(1, hidden_dim), + nn.ELU(inplace=True) + ) + + self.blocks = nn.ModuleList([BROMLPBlock(hidden_dim, hidden_dim) for _ in range(num_blocks)]) + self.post_ln = nn.GroupNorm(1, hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem(x) + for blk in self.blocks: + x = blk(x) + x = self.post_ln(x) + return x + + def get_out_size(self) -> int: + return self.hidden_dim + + + +class BROActorEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + + self.model = BROEncoder( + cfg=cfg, + obs_space=obs_space, + hidden_dim=cfg.actor_hidden_dim, + num_blocks=cfg.actor_depth, + ) + + def forward(self, x): + return self.model(x) + + def get_out_size(self): + return self.model.get_out_size() + + +class BROCriticEncoder(Encoder): + def __init__(self, cfg: Config, obs_space: ObsSpace): + super().__init__(cfg) + + self.model = BROEncoder( + cfg=cfg, + obs_space=obs_space, + hidden_dim=cfg.critic_hidden_dim, + num_blocks=cfg.critic_depth, + ) + + def forward(self, x): + return self.model(x) + + def get_out_size(self): + return self.model.get_out_size() \ No newline at end of file diff --git a/sf_examples/mujoco/train_mujoco.py b/sf_examples/mujoco/train_mujoco.py index 9c1f33d26..639ec8403 100644 --- a/sf_examples/mujoco/train_mujoco.py +++ b/sf_examples/mujoco/train_mujoco.py @@ -21,6 +21,8 @@ from sf_examples.mujoco.models import ( SimBaActorEncoder, SimBaCriticEncoder, + BROActorEncoder, + BROCriticEncoder, ) def add_extra_params_general(parser): @@ -71,11 +73,14 @@ def add_extra_params_general(parser): class ActorCriticDifferentEncoders(ActorCriticSeparateWeights): def __init__(self, model_factory, obs_space, action_space, cfg): super().__init__(model_factory, obs_space, action_space, cfg) + if cfg.model == "simba": + self.actor_encoder = SimBaActorEncoder(cfg, obs_space) + self.critic_encoder = SimBaCriticEncoder(cfg, obs_space) + elif cfg.model == "bro": + self.actor_encoder = BROActorEncoder(cfg, obs_space) + self.critic_encoder = BROCriticEncoder(cfg, obs_space) - self.actor_encoder = SimBaActorEncoder(cfg, obs_space) self.actor_core = model_factory.make_model_core_func(cfg, self.actor_encoder.get_out_size()) - - self.critic_encoder = SimBaCriticEncoder(cfg, obs_space) self.critic_core = model_factory.make_model_core_func(cfg, self.critic_encoder.get_out_size()) self.encoders = [self.actor_encoder, self.critic_encoder] @@ -101,7 +106,7 @@ def make_mujoco_actor_critic(cfg: Config, obs_space: ObsSpace, action_space: Act model_factory = global_model_factory() # obs_space = obs_space_without_action_mask(obs_space) - if cfg.model == "simba": + if cfg.model in ["simba", "bro"]: if cfg.actor_critic_share_weights: return ActorCriticSharedWeights(model_factory, obs_space, action_space, cfg) else: @@ -184,6 +189,8 @@ def make_mujoco_encoder(cfg: Config, obs_space: ObsSpace) -> Encoder: return MultiInputEncoder(cfg, obs_space) elif cfg.model == "simba": return SimBaActorEncoder(cfg, obs_space) + elif cfg.model == "bro": + return BROActorEncoder(cfg, obs_space) def register_mujoco_components():