Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions slime/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ def init(
Timer().start("train_wait")
return

start_rollout_id = loaded_rollout_id + 1
expected_start_rollout_id = 0 if loaded_rollout_id == 0 else (loaded_rollout_id + 1)
assert (
args.start_rollout_id == expected_start_rollout_id
), f"{args.start_rollout_id=} {expected_start_rollout_id=}"

self.weights = {"actor": {}}
self.update_cpu_params_dict(self.weights["actor"])

Expand Down Expand Up @@ -137,7 +141,6 @@ def init(
self.prof.start()

Timer().start("train_wait")
return start_rollout_id

@torch.no_grad()
def update_cpu_params_dict(self, params_dict: Dict[str, torch.Tensor]) -> None:
Expand Down
8 changes: 1 addition & 7 deletions slime/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,7 @@ def create_training_models(args, pgs, rollout_manager, wandb_run_id):
else:
critic_model = None

start_rollout_ids = ray.get(
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
)

assert len(set(start_rollout_ids)) == 1
if args.start_rollout_id is None:
args.start_rollout_id = start_rollout_ids[0]
ray.get(actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss))

if args.use_critic:
ray.get(critic_init_handle)
Expand Down
11 changes: 5 additions & 6 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from slime.backends.sglang_utils.arguments import add_sglang_arguments
from slime.backends.sglang_utils.arguments import validate_args as sglang_validate_args
from slime.utils.checkpoint_utils import get_latest_checkpointed_iteration


def reset_arg(parser, name, **kwargs):
Expand Down Expand Up @@ -1126,19 +1127,17 @@ def slime_validate_args(args):
"please make sure it is a valid megatron checkpoint directory."
)

# TODO: During loading, we need to set the start_rollout_id here.
if (
args.load is None
or not os.path.exists(args.load)
or not os.path.exists(os.path.join(args.load, "latest_checkpointed_iteration.txt"))
):
load_ckpt_iter = get_latest_checkpointed_iteration(args)
if load_ckpt_iter is None:
args.no_load_optim = True
args.no_load_rng = True
args.finetune = True
args.load = args.ref_load
if args.ref_ckpt_step is not None:
args.ckpt_step = args.ref_ckpt_step
args.start_rollout_id = 0
else:
args.start_rollout_id = load_ckpt_iter + 1

if args.eval_interval is not None:
assert args.eval_prompt_data is not None, "eval_prompt_data must be set when eval_interval is set"
Expand Down
19 changes: 19 additions & 0 deletions slime/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from pathlib import Path
from typing import Optional


def get_latest_checkpointed_iteration(args) -> Optional[int]:
    """Return the iteration of the checkpoint that ``args`` points at, if any.

    An explicitly requested ``args.ckpt_step`` takes precedence. Otherwise the
    iteration is read from ``latest_checkpointed_iteration.txt`` inside the
    ``args.load`` directory.

    :param args: The Megatron arguments; ``args.ckpt_step`` and ``args.load``
        are read.
    :return: The checkpoint iteration, or ``None`` when ``args.load`` is unset
        or holds no tracker file.
    :raises ValueError: If the tracker file exists but does not contain an
        integer (for example, it is empty or holds a non-numeric marker).
    """
    if (ckpt_step := args.ckpt_step) is not None:
        return ckpt_step

    if args.load is None:
        return None

    tracker_path = Path(args.load) / "latest_checkpointed_iteration.txt"
    try:
        # Read directly rather than exists()-then-read, so a tracker file that
        # disappears in between is still reported as "no checkpoint".
        contents = tracker_path.read_text()
    except (FileNotFoundError, NotADirectoryError):
        return None

    try:
        return int(contents)
    except ValueError:
        # Same exception type as a bare int() failure, but the message names
        # the offending file so a bad checkpoint directory is easy to spot.
        raise ValueError(
            f"{tracker_path} does not contain an integer iteration: {contents.strip()!r}"
        ) from None
Loading