Skip to content

Commit b608398

Browse files
authored
Merge branch 'main' into feature/register
2 parents 14c0ba4 + 843c94e commit b608398

File tree

9 files changed

+50
-27
lines changed

9 files changed

+50
-27
lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def init(self, args, role, wandb_run_id, with_ref=False):
5454
if role == "critic":
5555
self.args.load = self.args.critic_load
5656
self.args.save = self.args.critic_save
57+
self.args.lr = self.args.critic_lr
5758

5859
(self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer(
5960
args, role
@@ -256,11 +257,16 @@ def train_critic(self, rollout_id, rollout_data):
256257
self.model,
257258
data_iterator,
258259
num_microbatches,
259-
)
260-
values = [value.squeeze(-1) for value in values["values"]]
261-
values, log_probs, ref_log_probs = sync_actor_critic_data(
262-
self.args, values, None, None, self._actor_critic_groups
263-
)
260+
)["values"]
261+
262+
if rollout_id < self.args.num_critic_only_steps:
263+
# we will only use the shape of log_probs in this situation
264+
log_probs = values
265+
ref_log_probs = values
266+
else:
267+
values, log_probs, ref_log_probs = sync_actor_critic_data(
268+
self.args, values, None, None, self._actor_critic_groups
269+
)
264270

265271
rollout_data.update(
266272
{

slime/backends/megatron_utils/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def log_rollout_data(rollout_id, args, rollout_data):
243243
# NOTE: Here we have to do the clone().detach(), otherwise the tensor will be
244244
# modified in place and will cause problem for the next rollout.
245245
val = torch.cat(val).clone().detach()
246-
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages"]:
246+
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]:
247247
sum_of_sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks)
248248
val = cp_size * sum_of_sample_mean(val) / len(loss_masks)
249249
else:

slime/backends/megatron_utils/loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def get_values(
117117
response_lengths=response_lengths,
118118
):
119119
assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}"
120-
value_list.append(logits_chunk)
120+
value_list.append(logits_chunk.squeeze(-1))
121121

122122
return {
123123
"values": value_list,
@@ -366,19 +366,22 @@ def value_loss_function(args, batch, logits, sum_of_sample_mean):
366366

367367
returns = torch.cat(batch["returns"], dim=0)
368368

369+
values_clipfrac = torch.abs(values - old_values) > args.value_clip
369370
values_clipped = old_values + (values - old_values).clamp(-args.value_clip, args.value_clip)
370371
surr1 = (values_clipped - returns) ** 2
371372
surr2 = (values - returns) ** 2
372373
loss = torch.max(surr1, surr2)
373374

374375
loss = sum_of_sample_mean(loss)
376+
values_clipfrac = sum_of_sample_mean(values_clipfrac.float())
375377

376378
# make sure the gradient could backprop correctly.
377379
if values.numel() == 0:
378380
loss += 0 * values.sum()
379381

380382
reported_loss = {
381383
"value_loss": loss.clone().detach(),
384+
"value_clipfrac": values_clipfrac.clone().detach(),
382385
}
383386

384387
return loss, reported_loss

slime/backends/megatron_utils/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,10 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
462462
for key, val in loss_dict.items()
463463
}
464464
log_dict["train/grad_norm"] = grad_norm
465+
role = getattr(model[0], "role", "actor")
466+
role_tag = "" if role == "actor" else f"{role}-"
465467
for param_group_id, param_group in enumerate(optimizer.param_groups):
466-
log_dict[f"train/lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)
468+
log_dict[f"train/{role_tag}lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)
467469

468470
if args.use_wandb:
469471
log_dict["train/step"] = accumulated_step_id
@@ -475,7 +477,7 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
475477
if accumulated_step_id == 0 and "train/kl_loss" in log_dict:
476478
assert log_dict["train/kl_loss"] == 0.0
477479

478-
print(f"step {accumulated_step_id}: {log_dict}")
480+
print(f"{role_tag}step {accumulated_step_id}: {log_dict}")
479481
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
480482
if pre_hook_enabled:
481483
disable_forward_pre_hook(model)
@@ -501,6 +503,7 @@ def save(iteration, model, optimizer, opt_param_scheduler):
501503

502504
def initialize_model_and_optimizer(args, role: str = "actor"):
503505
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(args, role)
506+
setattr(model[0], "role", role)
504507
clear_memory()
505508
iteration, _ = load_checkpoint(
506509
model,

slime/backends/megatron_utils/model_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def __init__(
3030
if self.sequence_parallel:
3131
self.weight.sequence_parallel = True
3232

33+
self.weight.data.normal_(mean=0.0, std=0.02)
34+
if bias:
35+
self.bias.data.zero_()
36+
3337
def forward(
3438
self,
3539
input_,

slime/utils/arguments.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,12 @@ def add_algo_arguments(parser):
518518
reset_arg(parser, "--seed", type=int, default=1234)
519519
reset_arg(parser, "--clip-grad", type=float, default=1.0)
520520
reset_arg(parser, "--calculate-per-token-loss", action="store_true")
521+
reset_arg(parser, "--lr", type=float, default=1e-6)
521522

523+
parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps")
522524
parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.")
523525
parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.")
526+
parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model")
524527

525528
parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range")
526529
parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range")
@@ -984,9 +987,6 @@ def slime_validate_args(args):
984987
args.ckpt_step = args.ref_ckpt_step
985988
args.start_rollout_id = 0
986989

987-
if args.critic_load is None:
988-
args.critic_load = args.load
989-
990990
if args.eval_interval is not None:
991991
assert args.eval_prompt_data is not None, "eval_prompt_data must be set when eval_interval is set"
992992
if len(args.eval_prompt_data) == 1:
@@ -1032,6 +1032,10 @@ def slime_validate_args(args):
10321032
args.critic_num_gpus_per_node = args.actor_num_gpus_per_node
10331033
if args.critic_num_nodes is None:
10341034
args.critic_num_nodes = args.actor_num_nodes
1035+
if args.critic_load is None:
1036+
args.critic_load = args.load
1037+
if args.critic_lr is None:
1038+
args.critic_lr = args.lr
10351039

10361040
if args.debug_rollout_only:
10371041
if args.colocate and args.rollout_num_gpus is None:

slime/utils/http_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def init_http_client(concurrency: int):
8383
if _http_client is None:
8484
_http_client = httpx.AsyncClient(
8585
limits=httpx.Limits(max_connections=concurrency),
86-
timeout=httpx.Timeout(None),
86+
timeout=httpx.Timeout(None, connect=5.0),
8787
)
8888

8989

@@ -113,7 +113,6 @@ async def post(url, payload, max_retries=60):
113113

114114

115115
async def get(url):
116-
# never timeout
117116
response = await _http_client.get(url)
118117
response.raise_for_status()
119118
output = response.json()

train.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,10 @@ def train(args):
7777
ray.get(rollout_manager.offload.remote())
7878

7979
if args.use_critic:
80-
ray.get(
81-
[
82-
actor_model.async_train(rollout_id, rollout_data_ref),
83-
critic_model.async_train(rollout_id, rollout_data_ref),
84-
]
85-
)
80+
critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
81+
if rollout_id >= args.num_critic_only_steps:
82+
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
83+
ray.get(critic_train_handle)
8684
else:
8785
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
8886

@@ -95,9 +93,12 @@ def train(args):
9593
ray.get(rollout_manager.save.remote(rollout_id))
9694

9795
if args.offload:
98-
ray.get(actor_model.async_offload())
9996
if args.use_critic:
10097
ray.get(critic_model.async_offload())
98+
if rollout_id >= args.num_critic_only_steps:
99+
ray.get(actor_model.async_offload())
100+
else:
101+
ray.get(actor_model.async_offload())
101102

102103
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
103104

train_async.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def train(args):
4444
start_rollout_ids = ray.get(
4545
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
4646
)
47+
4748
assert len(set(start_rollout_ids)) == 1
4849
if args.start_rollout_id is None:
4950
args.start_rollout_id = start_rollout_ids[0]
@@ -55,6 +56,10 @@ def train(args):
5556
ray.get(critic_init_handle)
5657
ray.get(actor_model.async_connect(critic_model))
5758

59+
if args.use_critic:
60+
ray.get(critic_init_handle)
61+
ray.get(actor_model.async_connect(critic_model))
62+
5863
# always update weight first so that sglang has the loaded weights from training.
5964
ray.get(actor_model.async_update_weights())
6065

@@ -70,12 +75,10 @@ def train(args):
7075
rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1)
7176

7277
if args.use_critic:
73-
ray.get(
74-
[
75-
critic_model.async_train(rollout_id, rollout_data_curr_ref),
76-
actor_model.async_train(rollout_id, rollout_data_curr_ref),
77-
]
78-
)
78+
critic_train_handle = critic_model.async_train(rollout_id, rollout_data_curr_ref)
79+
if rollout_id >= args.num_critic_only_steps:
80+
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
81+
ray.get(critic_train_handle)
7982
else:
8083
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
8184

0 commit comments

Comments (0)