diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py
index 32f61b636..a7f5b655d 100644
--- a/learning/train_jax_ppo.py
+++ b/learning/train_jax_ppo.py
@@ -68,6 +68,11 @@ f"Name of the environment. One of {', '.join(registry.ALL_ENVS)}",
 )
 _IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
+_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string(
+    "playground_config_overrides",
+    None,
+    "Overrides for the playground env config.",
+)
 _VISION = flags.DEFINE_boolean("vision", False, "Use vision input")
 _LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
     "load_checkpoint_path", None, "Path to load checkpoint from"
 )
@@ -260,7 +265,12 @@ def main(argv):
   if _VISION.value:
     env_cfg.vision = True
     env_cfg.vision_config.render_batch_size = ppo_params.num_envs
-  env = registry.load(_ENV_NAME.value, config=env_cfg)
+  env_cfg_overrides = {}
+  if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
+    env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
+  env = registry.load(
+      _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
+  )
   if _RUN_EVALS.present:
     ppo_params.run_evals = _RUN_EVALS.value
   if _LOG_TRAINING_METRICS.present:
@@ -269,6 +279,8 @@
     ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value
 
   print(f"Environment Config:\n{env_cfg}")
+  if env_cfg_overrides:
+    print(f"Environment Config Overrides:\n{env_cfg_overrides}\n")
   print(f"PPO Training Parameters:\n{ppo_params}")
 
   # Generate unique experiment name
@@ -399,7 +411,9 @@ def progress(num_steps, metrics):
   # Load evaluation environment.
   eval_env = None
   if not _VISION.value:
-    eval_env = registry.load(_ENV_NAME.value, config=env_cfg)
+    eval_env = registry.load(
+        _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
+    )
   num_envs = 1
   if _VISION.value:
     num_envs = env_cfg.vision_config.render_batch_size
@@ -410,7 +424,9 @@ def progress(num_steps, metrics):
     from rscope import brax as rscope_utils
 
     if not _VISION.value:
-      rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
+      rscope_env = registry.load(
+          _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
+      )
     rscope_env = wrapper.wrap_for_brax_training(
         rscope_env,
         episode_length=ppo_params.episode_length,
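
Usage note: the new flag's value is parsed with json.loads and forwarded to
registry.load as config_overrides, so it must be a single JSON object. A
minimal sketch of an invocation, assuming the environment flag is named
env_name as elsewhere in these scripts; the env and override keys here are
illustrative, not taken from this diff:

    python learning/train_jax_ppo.py \
        --env_name=LeapCubeReorient \
        --playground_config_overrides='{"episode_length": 500, "ctrl_dt": 0.02}'

Single-quoting the JSON keeps the braces and inner double quotes intact in
the shell.
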
diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py
index 71228eef8..f7ce76c4b 100644
--- a/learning/train_rsl_rl.py
+++ b/learning/train_rsl_rl.py
@@ -54,6 +54,12 @@
         f"{', '.join(mujoco_playground.registry.ALL_ENVS)}"
     ),
 )
+_IMPL = flags.DEFINE_enum("impl", "jax", ["jax", "warp"], "MJX implementation")
+_PLAYGROUND_CONFIG_OVERRIDES = flags.DEFINE_string(
+    "playground_config_overrides",
+    None,
+    "Overrides for the playground env config.",
+)
 _LOAD_RUN_NAME = flags.DEFINE_string(
     "load_run_name", None, "Run name to load from (for checkpoint restoration)."
 )
@@ -108,8 +114,14 @@ def main(argv):
   # Load default config from registry
   env_cfg = registry.get_default_config(_ENV_NAME.value)
+  env_cfg.impl = _IMPL.value
   print(f"Environment config:\n{env_cfg}")
 
+  env_cfg_overrides = {}
+  if _PLAYGROUND_CONFIG_OVERRIDES.value is not None:
+    env_cfg_overrides = json.loads(_PLAYGROUND_CONFIG_OVERRIDES.value)
+    print(f"Environment config overrides:\n{env_cfg_overrides}\n")
+
   # Generate unique experiment name
   now = datetime.now()
   timestamp = now.strftime("%Y%m%d-%H%M%S")
 
@@ -152,7 +164,9 @@ def render_callback(_, state):
     render_trajectory.append(state)
 
   # Create the environment
-  raw_env = registry.load(_ENV_NAME.value, config=env_cfg)
+  raw_env = registry.load(
+      _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
+  )
   brax_env = wrapper_torch.RSLRLBraxWrapper(
       raw_env,
       num_envs,
@@ -206,7 +220,9 @@ def render_callback(_, state):
   policy = runner.get_inference_policy(device=device)
 
   # Example: run a single rollout
-  eval_env = registry.load(_ENV_NAME.value, config=env_cfg)
+  eval_env = registry.load(
+      _ENV_NAME.value, config=env_cfg, config_overrides=env_cfg_overrides
+  )
   jit_reset = jax.jit(eval_env.reset)
   jit_step = jax.jit(eval_env.step)
 
diff --git a/mujoco_playground/_src/manipulation/leap_hand/reorient.py b/mujoco_playground/_src/manipulation/leap_hand/reorient.py
index ce8e931d5..6b2716263 100644
--- a/mujoco_playground/_src/manipulation/leap_hand/reorient.py
+++ b/mujoco_playground/_src/manipulation/leap_hand/reorient.py
@@ -628,4 +628,4 @@ def rand(rng):
       "actuator_biasprm": actuator_biasprm,
   })
 
-  return model, in_axes
+  return model, in_axes
\ No newline at end of file
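
Usage note: train_rsl_rl.py is wired to the same overrides path and also
gains the impl selector, which is copied onto env_cfg.impl before the
environment is loaded. A sketch of an equivalent invocation, with
illustrative flag values (warp is the other value the enum accepts):

    python learning/train_rsl_rl.py \
        --env_name=LeapCubeReorient \
        --impl=warp \
        --playground_config_overrides='{"episode_length": 1000}'
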