
Commit ff0ea56

Andrew-Luo1 authored and copybara-github committed
Copybara import of the project:

-- 584f657 by andrew <[email protected]>: rscope supporting changes
-- 75b9c39 by andrew <[email protected]>: run formatter
-- a46f5e1 by andrew <[email protected]>: elaborate on _model_assets
-- d3500c1 by andrew <[email protected]>: add deterministic eval option for train script
-- 1635951 by andrew <[email protected]>: modify README.md
-- e602864 by andrew <[email protected]>: update naming convention
-- 3bdb45f by andrew <[email protected]>: run formatting
-- a145b8f by andrew <[email protected]>: add option to skip eval rollouts
-- b86274d by andrew <[email protected]>: update readme
-- 2ac436f by andrew <[email protected]>: move trace_fn out of train_jax_ppo.py
-- 1fa8c32 by andrew <[email protected]>: switch to OOP rscope interface
-- 47125c7 by andrew <[email protected]>: switch training script to the new metrics logging api
-- ce537d7 by andrew <[email protected]>: improve reward logging
-- 5deaaa9 by andrew <[email protected]>: improve training printouts
-- 67f93ba by andrew <[email protected]>: resolve PR comments
-- 0114224 by andrew <[email protected]>: modify README

COPYBARA_INTEGRATE_REVIEW=#129 from Andrew-Luo1:rscope_new 0114224
PiperOrigin-RevId: 764350162
Change-Id: I701bfc55a1ac994d3440d2d2fe03985ac3f399d5
1 parent 65020c6 commit ff0ea56

File tree

11 files changed (+136, -22 lines)


README.md

Lines changed: 18 additions & 0 deletions
@@ -63,6 +63,24 @@ For vision-based environments, please refer to the installation instructions in
 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb) | Training CartPole from Vision |
 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb) | Robotic Manipulation from Vision |
 
+## Running from CLI
+> [!IMPORTANT]
+> Assumes installation from source.
+
+For basic usage, navigate to the repo's directory and run:
+```bash
+python learning/train_jax_ppo.py --env_name CartpoleBalance
+```
+
+### Training Visualization
+
+To interactively view trajectories throughout training with [rscope](https://github.com/Andrew-Luo1/rscope/tree/main), install it (`pip install rscope`) and run:
+
+```
+python learning/train_jax_ppo.py --env_name PandaPickCube --rscope_envs 16 --run_evals=False --deterministic_rscope=True
+# In a separate terminal
+python -m rscope
+```
 
 ## FAQ
 
learning/train_jax_ppo.py

Lines changed: 91 additions & 11 deletions
@@ -132,6 +132,33 @@
     "policy_obs_key", "state", "Policy obs key"
 )
 _VALUE_OBS_KEY = flags.DEFINE_string("value_obs_key", "state", "Value obs key")
+_RSCOPE_ENVS = flags.DEFINE_integer(
+    "rscope_envs",
+    None,
+    "Number of parallel environment rollouts to save for the rscope viewer",
+)
+_DETERMINISTIC_RSCOPE = flags.DEFINE_boolean(
+    "deterministic_rscope",
+    True,
+    "Run deterministic rollouts for the rscope viewer",
+)
+_RUN_EVALS = flags.DEFINE_boolean(
+    "run_evals",
+    True,
+    "Run evaluation rollouts between policy updates.",
+)
+_LOG_TRAINING_METRICS = flags.DEFINE_boolean(
+    "log_training_metrics",
+    False,
+    "Whether to log training metrics and callback to progress_fn. Significantly"
+    " slows down training if too frequent.",
+)
+_TRAINING_METRICS_STEPS = flags.DEFINE_integer(
+    "training_metrics_steps",
+    1_000_000,
+    "Number of steps between logging training metrics. Increase if training"
+    " experiences slowdown.",
+)
 
 
 def get_rl_config(env_name: str) -> config_dict.ConfigDict:
@@ -151,6 +178,24 @@ def get_rl_config(env_name: str) -> config_dict.ConfigDict:
   raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.")
 
 
+def rscope_fn(full_states, obs, rew, done):
+  """
+  All arrays are of shape (unroll_length, rscope_envs, ...)
+  full_states: dict with keys 'qpos', 'qvel', 'time', 'metrics'
+  obs: nd.array or dict obs based on env configuration
+  rew: nd.array rewards
+  done: nd.array done flags
+  """
+  # Calculate cumulative rewards per episode, stopping at first done flag
+  done_mask = jp.cumsum(done, axis=0)
+  valid_rewards = rew * (done_mask == 0)
+  episode_rewards = jp.sum(valid_rewards, axis=0)
+  print(
+      "Collected rscope rollouts with reward"
+      f" {episode_rewards.mean():.3f} +- {episode_rewards.std():.3f}"
+  )
+
+
 def main(argv):
   """Run training and evaluation for the specified environment."""
 
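A side note on the masking trick in `rscope_fn` above: `jp.cumsum(done, axis=0)` becomes non-zero from the first terminated step onward, so multiplying the rewards by `(done_mask == 0)` keeps only the reward collected before each rollout's first `done` flag. A minimal sketch with toy arrays (values and shapes are illustrative only, not from the commit):

```python
import jax.numpy as jp

# Rewards and done flags for an unroll of length 5 over 2 parallel envs.
# Env 0 terminates at step 2; env 1 never terminates.
rew = jp.ones((5, 2))
done = jp.array([
    [0.0, 0.0],
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 0.0],
    [0.0, 0.0],
])

done_mask = jp.cumsum(done, axis=0)     # non-zero from the first done step onward
valid_rewards = rew * (done_mask == 0)  # zero out rewards at and after the first done
episode_rewards = jp.sum(valid_rewards, axis=0)
print(episode_rewards)                  # [2. 5.]
```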
@@ -209,11 +254,16 @@ def main(argv):
     ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
   if _VALUE_OBS_KEY.present:
     ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value
-
   if _VISION.value:
     env_cfg.vision = True
     env_cfg.vision_config.render_batch_size = ppo_params.num_envs
   env = registry.load(_ENV_NAME.value, config=env_cfg)
+  if _RUN_EVALS.present:
+    ppo_params.run_evals = _RUN_EVALS.value
+  if _LOG_TRAINING_METRICS.present:
+    ppo_params.log_training_metrics = _LOG_TRAINING_METRICS.value
+  if _TRAINING_METRICS_STEPS.present:
+    ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value
 
   print(f"Environment Config:\n{env_cfg}")
   print(f"PPO Training Parameters:\n{ppo_params}")
@@ -268,13 +318,6 @@ def main(argv):
   with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
     json.dump(env_cfg.to_dict(), fp, indent=4)
 
-  # Define policy parameters function for saving checkpoints
-  def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
-    orbax_checkpointer = ocp.PyTreeCheckpointer()
-    save_args = orbax_utils.save_args_from_target(params)
-    path = ckpt_path / f"{current_step}"
-    orbax_checkpointer.save(path, params, force=True, save_args=save_args)
-
   training_params = dict(ppo_params)
   if "network_factory" in training_params:
     del training_params["network_factory"]
@@ -319,9 +362,9 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
       ppo.train,
       **training_params,
       network_factory=network_factory,
-      policy_params_fn=policy_params_fn,
       seed=_SEED.value,
       restore_checkpoint_path=restore_checkpoint_path,
+      save_checkpoint_path=ckpt_path,
       wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training,
       num_eval_envs=num_eval_envs,
   )
@@ -341,18 +384,55 @@ def progress(num_steps, metrics):
     for key, value in metrics.items():
       writer.add_scalar(key, value, num_steps)
     writer.flush()
-
-    print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _RUN_EVALS.value:
+      print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _LOG_TRAINING_METRICS.value:
+      if "episode/sum_reward" in metrics:
+        print(
+            f"{num_steps}: mean episode"
+            f" reward={metrics['episode/sum_reward']:.3f}"
+        )
 
   # Load evaluation environment
   eval_env = (
       None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
   )
 
+  policy_params_fn = lambda *args: None
+  if _RSCOPE_ENVS.value:
+    # Interactive visualisation of policy checkpoints
+    from rscope import brax as rscope_utils
+
+    if not _VISION.value:
+      rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
+      rscope_env = wrapper.wrap_for_brax_training(
+          rscope_env,
+          episode_length=ppo_params.episode_length,
+          action_repeat=ppo_params.action_repeat,
+          randomization_fn=training_params.get("randomization_fn"),
+      )
+    else:
+      rscope_env = env
+
+    rscope_handle = rscope_utils.BraxRolloutSaver(
+        rscope_env,
+        ppo_params,
+        _VISION.value,
+        _RSCOPE_ENVS.value,
+        _DETERMINISTIC_RSCOPE.value,
+        jax.random.PRNGKey(_SEED.value),
+        rscope_fn,
+    )
+
+    def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
+      rscope_handle.set_make_policy(make_policy)
+      rscope_handle.dump_rollout(params)
+
   # Train or load the model
   make_inference_fn, params, _ = train_fn(  # pylint: disable=no-value-for-parameter
       environment=env,
       progress_fn=progress,
+      policy_params_fn=policy_params_fn,
       eval_env=None if _VISION.value else eval_env,
   )
 
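One observation on the wiring above: `rscope_fn` is simply a post-rollout callback handed to `BraxRolloutSaver`, and any function with the same `(full_states, obs, rew, done)` signature and `(unroll_length, rscope_envs, ...)`-shaped arrays should slot in. A hedged sketch of an alternative callback that reports episode length instead of reward (illustrative only, not part of this commit):

```python
import jax.numpy as jp


def rscope_episode_length_fn(full_states, obs, rew, done):
  """Prints how many steps each saved rollout lasted before its first done flag."""
  del full_states, obs, rew  # Unused in this variant.
  unroll_length = done.shape[0]
  first_done = jp.argmax(done > 0, axis=0)  # index of the first done flag per env
  never_done = jp.all(done == 0, axis=0)    # rollouts that never terminated
  episode_lengths = jp.where(never_done, unroll_length, first_done)
  print(
      "Collected rscope rollouts with mean episode length"
      f" {episode_lengths.mean():.1f}"
  )
```

It would be passed to `BraxRolloutSaver` in place of `rscope_fn` as the final constructor argument.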
mujoco_playground/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from mujoco_playground._src.mjx_env import render_array
 from mujoco_playground._src.mjx_env import State
 from mujoco_playground._src.mjx_env import step
+
 # pylint: enable=g-importing-member
 
 __all__ = [

mujoco_playground/_src/dm_control_suite/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -150,6 +150,8 @@ def load(
     An instance of the environment.
   """
   if env_name not in _envs:
-    raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
+    raise ValueError(
+        f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
+    )
   config = config or get_default_config(env_name)
   return _envs[env_name](config=config, config_overrides=config_overrides)

mujoco_playground/_src/dm_control_suite/cartpole.py

Lines changed: 2 additions & 1 deletion
@@ -90,8 +90,9 @@ def __init__(
       self._get_reward = self._dense_reward
 
     self._xml_path = _XML_PATH.as_posix()
+    self._model_assets = common.get_assets()
     self._mj_model = mujoco.MjModel.from_xml_string(
-        _XML_PATH.read_text(), common.get_assets()
+        _XML_PATH.read_text(), self._model_assets
     )
     self._mj_model.opt.timestep = self.sim_dt
     self._mjx_model = mjx.put_model(self._mj_model)

mujoco_playground/_src/locomotion/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -182,7 +182,9 @@ def load(
     An instance of the environment.
   """
   if env_name not in _envs:
-    raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
+    raise ValueError(
+        f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
+    )
   config = config or get_default_config(env_name)
   return _envs[env_name](config=config, config_overrides=config_overrides)
 
mujoco_playground/_src/locomotion/t1/randomize.py

Lines changed: 1 addition & 2 deletions
@@ -18,7 +18,6 @@
 from mujoco import mjx
 import numpy as np
 
-
 FLOOR_GEOM_ID = 0
 TORSO_BODY_ID = 1
 ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]])
@@ -30,7 +29,7 @@ def rand_dynamics(rng):
     # Floor friction: =U(0.4, 1.0).
     rng, key = jax.random.split(rng)
     geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set(
-        jax.random.uniform(key, minval=0.2, maxval=.6)
+        jax.random.uniform(key, minval=0.2, maxval=0.6)
     )
 
     rng, key = jax.random.split(rng)

mujoco_playground/_src/manipulation/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -111,7 +111,9 @@ def load(
     An instance of the environment.
   """
   if env_name not in _envs:
-    raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
+    raise ValueError(
+        f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
+    )
   config = config or get_default_config(env_name)
   return _envs[env_name](config=config, config_overrides=config_overrides)
 
mujoco_playground/_src/manipulation/franka_emika_panda/panda.py

Lines changed: 3 additions & 2 deletions
@@ -71,7 +71,8 @@ def __init__(
 
     self._xml_path = xml_path.as_posix()
     xml = xml_path.read_text()
-    mj_model = mujoco.MjModel.from_xml_string(xml, assets=get_assets())
+    self._model_assets = get_assets()
+    mj_model = mujoco.MjModel.from_xml_string(xml, assets=self._model_assets)
     mj_model.opt.timestep = self.sim_dt
 
     self._mj_model = mj_model
@@ -108,7 +109,7 @@ def _post_init(self, obj_name: str, keyframe: str):
 
   @property
   def xml_path(self) -> str:
-    raise self._xml_path
+    return self._xml_path
 
   @property
   def action_size(self) -> int:

mujoco_playground/_src/mjx_env.py

Lines changed: 10 additions & 0 deletions
@@ -273,6 +273,16 @@ def observation_size(self) -> ObservationSize:
       return jax.tree_util.tree_map(lambda x: x.shape, obs)
     return obs.shape[-1]
 
+  @property
+  def model_assets(self) -> Dict[str, Any]:
+    """Dictionary of model assets to use with MjModel.from_xml_path"""
+    if hasattr(self, "_model_assets"):
+      return self._model_assets
+    raise NotImplementedError(
+        "_model_assets not defined for this environment"
+        "see cartpole.py for an example."
+    )
+
   def render(
       self,
       trajectory: List[State],
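The new `model_assets` property pairs with the `_model_assets` attributes added in `cartpole.py` and `panda.py` above: together with `xml_path`, it lets tools such as rscope rebuild the same `MjModel` outside the environment. A hedged usage sketch (assumes an environment that sets `_model_assets`, e.g. `CartpoleBalance`):

```python
import pathlib

import mujoco
from mujoco_playground import registry

env = registry.load("CartpoleBalance")
# Rebuild the environment's MuJoCo model from its XML and bundled assets.
mj_model = mujoco.MjModel.from_xml_string(
    pathlib.Path(env.xml_path).read_text(), env.model_assets
)
print(mj_model.nq, mj_model.nv)
```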
