99from slime .ray .placement_group import create_placement_groups , create_rollout_manager , create_training_models
1010from slime .utils .arguments import parse_args
1111from slime .utils .logging_utils import configure_logger
12+ from slime .utils .misc import should_run_periodic_action
1213from slime .utils .tracking_utils import init_tracking
1314
1415
@@ -61,7 +62,6 @@ def onload_rollout():
6162 # train loop.
6263 # note that for async training, one can change the position of the sync operation(ray.get).
6364 for rollout_id in range (args .start_rollout_id , args .num_rollout ):
64- # TODO extract the duplicated eval logic
6565 if args .eval_interval is not None and rollout_id == 0 :
6666 ray .get (rollout_manager .eval .remote (rollout_id ))
6767
@@ -78,10 +78,7 @@ def onload_rollout():
7878 else :
7979 ray .get (actor_model .async_train (rollout_id , rollout_data_ref ))
8080
81- if args .save_interval is not None and (
82- (rollout_id + 1 ) % args .save_interval == 0
83- or (num_rollout_per_epoch is not None and (rollout_id + 1 ) % num_rollout_per_epoch == 0 )
84- ):
81+ if should_run_periodic_action (rollout_id , args .save_interval , num_rollout_per_epoch ):
8582 if (not args .use_critic ) or (rollout_id >= args .num_critic_only_steps ):
8683 actor_model .save_model (rollout_id )
8784 if args .use_critic :
@@ -98,10 +95,7 @@ def onload_rollout():
9895 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_CUDA_GRAPH ]))
9996 ray .get (rollout_manager .onload .remote (tags = [GPU_MEMORY_TYPE_KV_CACHE ]))
10097
101- if args .eval_interval is not None and (
102- (rollout_id + 1 ) % args .eval_interval == 0
103- or (num_rollout_per_epoch is not None and (rollout_id + 1 ) % num_rollout_per_epoch == 0 )
104- ):
98+ if should_run_periodic_action (rollout_id , args .eval_interval , num_rollout_per_epoch ):
10599 ray .get (rollout_manager .eval .remote (rollout_id ))
106100
107101 ray .get (rollout_manager .dispose .remote ())
0 commit comments