diff --git a/slime/backends/megatron_utils/actor.py b/slime/backends/megatron_utils/actor.py index 675d5bc61..83259de10 100644 --- a/slime/backends/megatron_utils/actor.py +++ b/slime/backends/megatron_utils/actor.py @@ -78,7 +78,11 @@ def init( Timer().start("train_wait") return - start_rollout_id = loaded_rollout_id + 1 + expected_start_rollout_id = 0 if loaded_rollout_id == 0 else (loaded_rollout_id + 1) + assert ( + args.start_rollout_id == expected_start_rollout_id + ), f"{args.start_rollout_id=} {expected_start_rollout_id=}" + self.weights = {"actor": {}} self.update_cpu_params_dict(self.weights["actor"]) @@ -137,7 +141,6 @@ def init( self.prof.start() Timer().start("train_wait") - return start_rollout_id @torch.no_grad() def update_cpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None: diff --git a/slime/ray/placement_group.py b/slime/ray/placement_group.py index 0317644eb..c751ab5d3 100644 --- a/slime/ray/placement_group.py +++ b/slime/ray/placement_group.py @@ -140,13 +140,7 @@ def create_training_models(args, pgs, rollout_manager, wandb_run_id): else: critic_model = None - start_rollout_ids = ray.get( - actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss) - ) - - assert len(set(start_rollout_ids)) == 1 - if args.start_rollout_id is None: - args.start_rollout_id = start_rollout_ids[0] + ray.get(actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)) if args.use_critic: ray.get(critic_init_handle) diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index b65fcdc15..6a12f9cca 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -8,6 +8,7 @@ from slime.backends.sglang_utils.arguments import add_sglang_arguments from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args +from slime.utils.checkpoint_utils import get_latest_checkpointed_iteration def reset_arg(parser, name, **kwargs): @@ -1126,12 +1127,8 @@ def slime_validate_args(args): "please make sure it is a valid megatron checkpoint directory." ) - # TODO: During loading, we need to set the start_rollout_id here. - if ( - args.load is None - or not os.path.exists(args.load) - or not os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt")) - ): + load_ckpt_iter = get_latest_checkpointed_iteration(args) + if load_ckpt_iter is None: args.no_load_optim = True args.no_load_rng = True args.finetune = True @@ -1139,6 +1136,8 @@ def slime_validate_args(args): if args.ref_ckpt_step is not None: args.ckpt_step = args.ref_ckpt_step args.start_rollout_id = 0 + else: + args.start_rollout_id = load_ckpt_iter + 1 if args.eval_interval is not None: assert args.eval_prompt_data is not None, "eval_prompt_data must be set when eval_interval is set" diff --git a/slime/utils/checkpoint_utils.py b/slime/utils/checkpoint_utils.py new file mode 100644 index 000000000..2ad1a0526 --- /dev/null +++ b/slime/utils/checkpoint_utils.py @@ -0,0 +1,19 @@ +from pathlib import Path +from typing import Optional + + +def get_latest_checkpointed_iteration(args) -> Optional[int]: + """ + :param args: The Megatron arguments + """ + if (x := args.ckpt_step) is not None: + return x + + if args.load is None: + return None + + path_txt = Path(args.load) / "latest_checkpointed_iteration.txt" + if not path_txt.exists(): + return None + + return int(path_txt.read_text())