132132 "policy_obs_key" , "state" , "Policy obs key"
133133)
134134_VALUE_OBS_KEY = flags .DEFINE_string ("value_obs_key" , "state" , "Value obs key" )
135+ _RSCOPE_ENVS = flags .DEFINE_integer (
136+ "rscope_envs" ,
137+ None ,
138+ "Number of parallel environment rollouts to save for the rscope viewer" ,
139+ )
140+ _DETERMINISTIC_RSCOPE = flags .DEFINE_boolean (
141+ "deterministic_rscope" ,
142+ True ,
143+ "Run deterministic rollouts for the rscope viewer" ,
144+ )
145+ _RUN_EVALS = flags .DEFINE_boolean (
146+ "run_evals" ,
147+ True ,
148+ "Run evaluation rollouts between policy updates." ,
149+ )
150+ _LOG_TRAINING_METRICS = flags .DEFINE_boolean (
151+ "log_training_metrics" ,
152+ False ,
153+ "Whether to log training metrics and callback to progress_fn. Significantly"
154+ " slows down training if too frequent." ,
155+ )
156+ _TRAINING_METRICS_STEPS = flags .DEFINE_integer (
157+ "training_metrics_steps" ,
158+ 1_000_000 ,
159+ "Number of steps between logging training metrics. Increase if training"
160+ " experiences slowdown." ,
161+ )
135162
136163
137164def get_rl_config (env_name : str ) -> config_dict .ConfigDict :
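Note for reviewers: each `flags.DEFINE_*` call above returns an absl `FlagHolder`, whose `.value` is the parsed value and whose `.present` reports whether the flag was passed explicitly on the command line; the hunks below rely on `.present` to avoid clobbering per-environment defaults. A minimal standalone sketch of that API (the fake argv is invented for illustration):

```python
from absl import flags

_RSCOPE_ENVS = flags.DEFINE_integer(
    "rscope_envs", None, "Number of parallel environment rollouts to save"
)

# absl.app.run() normally does this parsing; here we feed a fake argv directly.
flags.FLAGS(["train_jax_ppo", "--rscope_envs=4"])
print(_RSCOPE_ENVS.value)    # 4
print(_RSCOPE_ENVS.present)  # True: the flag was set explicitly
```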
@@ -151,6 +178,24 @@ def get_rl_config(env_name: str) -> config_dict.ConfigDict:
   raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.")
 
 
+def rscope_fn(full_states, obs, rew, done):
+  """
+  All arrays are of shape (unroll_length, rscope_envs, ...)
+  full_states: dict with keys 'qpos', 'qvel', 'time', 'metrics'
+  obs: ndarray or dict obs based on env configuration
+  rew: ndarray rewards
+  done: ndarray done flags
+  """
+  # Calculate cumulative rewards per episode, stopping at the first done flag.
+  done_mask = jp.cumsum(done, axis=0)
+  valid_rewards = rew * (done_mask == 0)
+  episode_rewards = jp.sum(valid_rewards, axis=0)
+  print(
+      "Collected rscope rollouts with reward"
+      f" {episode_rewards.mean():.3f} +- {episode_rewards.std():.3f}"
+  )
+
+
 def main(argv):
   """Run training and evaluation for the specified environment."""
 
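The `cumsum` trick in `rscope_fn` masks out every reward from the first done flag onward in each environment column, so `episode_rewards` sums only the steps strictly before termination (the reward on the terminating step itself is excluded, since the mask flips at the step where done fires). A tiny self-contained check of that logic, with array values invented for illustration:

```python
import jax.numpy as jp

# Shape (unroll_length=4, rscope_envs=2): env 0 terminates at step 1,
# env 1 never terminates within the unroll.
done = jp.array([[0.0, 0.0],
                 [1.0, 0.0],
                 [0.0, 0.0],
                 [0.0, 0.0]])
rew = jp.ones((4, 2))

done_mask = jp.cumsum(done, axis=0)     # env 0 column: [0, 1, 1, 1]
valid_rewards = rew * (done_mask == 0)  # keeps steps strictly before done
print(jp.sum(valid_rewards, axis=0))    # [1. 4.]
```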
@@ -209,11 +254,16 @@ def main(argv):
     ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
   if _VALUE_OBS_KEY.present:
     ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value
-
   if _VISION.value:
     env_cfg.vision = True
     env_cfg.vision_config.render_batch_size = ppo_params.num_envs
   env = registry.load(_ENV_NAME.value, config=env_cfg)
+  if _RUN_EVALS.present:
+    ppo_params.run_evals = _RUN_EVALS.value
+  if _LOG_TRAINING_METRICS.present:
+    ppo_params.log_training_metrics = _LOG_TRAINING_METRICS.value
+  if _TRAINING_METRICS_STEPS.present:
+    ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value
 
   print(f"Environment Config:\n{env_cfg}")
   print(f"PPO Training Parameters:\n{ppo_params}")
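The `.present` guards mean a flag only overrides the `ConfigDict` defaults returned by `get_rl_config` when it was given explicitly, so environments whose configs disable evals or tune the metrics cadence keep those settings. The same pattern in a standalone sketch (flag and key names reused from this diff):

```python
from ml_collections import config_dict

ppo_params = config_dict.create(run_evals=True)  # per-env default

# Override only when --run_evals / --norun_evals appeared on the command line.
if _RUN_EVALS.present:
  ppo_params.run_evals = _RUN_EVALS.value
```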
@@ -268,13 +318,6 @@ def main(argv):
   with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
     json.dump(env_cfg.to_dict(), fp, indent=4)
 
-  # Define policy parameters function for saving checkpoints
-  def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
-    orbax_checkpointer = ocp.PyTreeCheckpointer()
-    save_args = orbax_utils.save_args_from_target(params)
-    path = ckpt_path / f"{current_step}"
-    orbax_checkpointer.save(path, params, force=True, save_args=save_args)
-
   training_params = dict(ppo_params)
   if "network_factory" in training_params:
     del training_params["network_factory"]
@@ -319,9 +362,9 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
       ppo.train,
       **training_params,
       network_factory=network_factory,
-      policy_params_fn=policy_params_fn,
       seed=_SEED.value,
       restore_checkpoint_path=restore_checkpoint_path,
+      save_checkpoint_path=ckpt_path,
       wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training,
       num_eval_envs=num_eval_envs,
   )
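This hunk swaps the hand-rolled orbax callback for brax's built-in `save_checkpoint_path` argument to `ppo.train`. If you are pinned to a brax version without that kwarg, the deleted callback can still be passed as `policy_params_fn`; a sketch reconstructed from the removed lines (`ckpt_path` closed over from `main`, as before):

```python
import orbax.checkpoint as ocp
from flax.training import orbax_utils

def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
  # Save the full params pytree under <ckpt_path>/<current_step>.
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  path = ckpt_path / f"{current_step}"
  orbax_checkpointer.save(path, params, force=True, save_args=save_args)
```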
@@ -341,18 +384,55 @@ def progress(num_steps, metrics):
     for key, value in metrics.items():
       writer.add_scalar(key, value, num_steps)
     writer.flush()
-
-    print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _RUN_EVALS.value:
+      print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _LOG_TRAINING_METRICS.value:
+      if "episode/sum_reward" in metrics:
+        print(
+            f"{num_steps}: mean episode"
+            f" reward={metrics['episode/sum_reward']:.3f}"
+        )
 
   # Load evaluation environment
   eval_env = (
       None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
   )
 
+  policy_params_fn = lambda *args: None
+  if _RSCOPE_ENVS.value:
+    # Interactive visualisation of policy checkpoints
+    from rscope import brax as rscope_utils
+
+    if not _VISION.value:
+      rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
+      rscope_env = wrapper.wrap_for_brax_training(
+          rscope_env,
+          episode_length=ppo_params.episode_length,
+          action_repeat=ppo_params.action_repeat,
+          randomization_fn=training_params.get("randomization_fn"),
+      )
+    else:
+      rscope_env = env
+
+    rscope_handle = rscope_utils.BraxRolloutSaver(
+        rscope_env,
+        ppo_params,
+        _VISION.value,
+        _RSCOPE_ENVS.value,
+        _DETERMINISTIC_RSCOPE.value,
+        jax.random.PRNGKey(_SEED.value),
+        rscope_fn,
+    )
+
+    def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
+      rscope_handle.set_make_policy(make_policy)
+      rscope_handle.dump_rollout(params)
+
   # Train or load the model
   make_inference_fn, params, _ = train_fn(  # pylint: disable=no-value-for-parameter
       environment=env,
       progress_fn=progress,
+      policy_params_fn=policy_params_fn,
       eval_env=None if _VISION.value else eval_env,
   )
 
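Since `ppo.train` accepts a single `policy_params_fn`, the no-op lambda default keeps the `train_fn` call site uniform whether or not `--rscope_envs` is set. If you ever need both rscope dumping and another per-checkpoint hook, a hypothetical composition helper (not part of this PR) would look like:

```python
def compose_policy_params_fns(*fns):
  """Fan a single brax policy_params_fn callback out to several hooks."""
  def composed(current_step, make_policy, params):
    for fn in fns:
      fn(current_step, make_policy, params)
  return composed

# e.g. policy_params_fn = compose_policy_params_fns(rscope_dump, my_logger)
```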