Skip to content

Commit 6b6403a

Browse files
committed
[feat] add --critic-lr and --num-critic-only-steps
1 parent 3e043e4 commit 6b6403a

File tree

9 files changed

+57
-22
lines changed

9 files changed

+57
-22
lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def init(self, args, role, wandb_run_id, with_ref=False):
5353
if role == "critic":
5454
self.args.load = self.args.critic_load
5555
self.args.save = self.args.critic_save
56+
self.args.lr = self.args.critic_lr
5657

5758
(self.model, self.optimizer, self.opt_param_scheduler, loaded_rollout_id) = initialize_model_and_optimizer(
5859
args, role
@@ -254,11 +255,16 @@ def train_critic(self, rollout_id, rollout_data):
254255
self.model,
255256
data_iterator,
256257
num_microbatches,
257-
)
258-
values = [value.squeeze(-1) for value in values["values"]]
259-
values, log_probs, ref_log_probs = sync_actor_critic_data(
260-
self.args, values, None, None, self._actor_critic_groups
261-
)
258+
)["values"]
259+
260+
if rollout_id < self.args.num_critic_only_steps:
261+
# we will only use the shape of log_probs in this situation
262+
log_probs = values
263+
ref_log_probs = values
264+
else:
265+
values, log_probs, ref_log_probs = sync_actor_critic_data(
266+
self.args, values, None, None, self._actor_critic_groups
267+
)
262268

263269
rollout_data.update(
264270
{

slime/backends/megatron_utils/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import torch
66
import torch.distributed as dist
77
import torch.nn.functional as F
8-
import wandb
98
from megatron.core import mpu
109
from megatron.core.packed_seq_params import PackedSeqParams
1110

11+
import wandb
1212
from slime.utils.data import get_minimum_num_micro_batch_size
1313
from slime.utils.flops_utils import calculate_fwd_flops
1414
from slime.utils.seqlen_balancing import get_seqlen_balanced_partitions
@@ -243,7 +243,7 @@ def log_rollout_data(rollout_id, args, rollout_data):
243243
# NOTE: Here we have to do the clone().detach(), otherwise the tensor will be
244244
# modified in place and will cause problems for the next rollout.
245245
val = torch.cat(val).clone().detach()
246-
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages"]:
246+
if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]:
247247
sum_of_sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks)
248248
val = cp_size * sum_of_sample_mean(val) / len(loss_masks)
249249
else:

slime/backends/megatron_utils/loss.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def get_values(
117117
response_lengths=response_lengths,
118118
):
119119
assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}"
120-
value_list.append(logits_chunk)
120+
value_list.append(logits_chunk.squeeze(-1))
121121

122122
return {
123123
"values": value_list,
@@ -366,19 +366,22 @@ def value_loss_function(args, batch, logits, sum_of_sample_mean):
366366

367367
returns = torch.cat(batch["returns"], dim=0)
368368

369+
values_clipfrac = torch.abs(values - old_values) > args.value_clip
369370
values_clipped = old_values + (values - old_values).clamp(-args.value_clip, args.value_clip)
370371
surr1 = (values_clipped - returns) ** 2
371372
surr2 = (values - returns) ** 2
372373
loss = torch.max(surr1, surr2)
373374

374375
loss = sum_of_sample_mean(loss)
376+
values_clipfrac = sum_of_sample_mean(values_clipfrac.float())
375377

376378
# make sure the gradient can backpropagate correctly.
377379
if values.numel() == 0:
378380
loss += 0 * values.sum()
379381

380382
reported_loss = {
381383
"value_loss": loss.clone().detach(),
384+
"value_clipfrac": values_clipfrac.clone().detach(),
382385
}
383386

384387
return loss, reported_loss

slime/backends/megatron_utils/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,10 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
462462
for key, val in loss_dict.items()
463463
}
464464
log_dict["train/grad_norm"] = grad_norm
465+
role = getattr(model[0], "role", "actor")
466+
role_tag = "" if role == "actor" else f"{role}-"
465467
for param_group_id, param_group in enumerate(optimizer.param_groups):
466-
log_dict[f"train/lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)
468+
log_dict[f"train/{role_tag}lr-pg_{param_group_id}"] = opt_param_scheduler.get_lr(param_group)
467469

468470
if args.use_wandb:
469471
log_dict["train/step"] = accumulated_step_id
@@ -475,7 +477,7 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
475477
if accumulated_step_id == 0 and "train/kl_loss" in log_dict:
476478
assert log_dict["train/kl_loss"] == 0.0
477479

478-
print(f"step {accumulated_step_id}: {log_dict}")
480+
print(f"{role_tag}step {accumulated_step_id}: {log_dict}")
479481
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
480482
if pre_hook_enabled:
481483
disable_forward_pre_hook(model)
@@ -501,6 +503,7 @@ def save(iteration, model, optimizer, opt_param_scheduler):
501503

502504
def initialize_model_and_optimizer(args, role: str = "actor"):
503505
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(args, role)
506+
setattr(model[0], "role", role)
504507
clear_memory()
505508
iteration, _ = load_checkpoint(
506509
model,

slime/backends/megatron_utils/model_provider.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ def __init__(
3030
if self.sequence_parallel:
3131
self.weight.sequence_parallel = True
3232

33+
self.weight.data.normal_(mean=0.0, std=0.02)
34+
if bias:
35+
self.bias.data.zero_()
36+
3337
def forward(
3438
self,
3539
input_,

slime/utils/arguments.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,12 @@ def add_algo_arguments(parser):
518518
reset_arg(parser, "--seed", type=int, default=1234)
519519
reset_arg(parser, "--clip-grad", type=float, default=1.0)
520520
reset_arg(parser, "--calculate-per-token-loss", action="store_true")
521+
reset_arg(parser, "--lr", type=float, default=1e-6)
521522

523+
parser.add_argument("--num-critic-only-steps", type=int, default=0, help="Number of critic only steps")
522524
parser.add_argument("--critic-load", type=str, default=None, help="The checkpoint for critic model.")
523525
parser.add_argument("--critic-save", type=str, default=None, help="The checkpoint for critic model.")
526+
parser.add_argument("--critic-lr", type=float, default=None, help="The lr for critic model")
524527

525528
parser.add_argument("--eps-clip", type=float, default=0.2, help="PPO clip range")
526529
parser.add_argument("--eps-clip-high", type=float, default=None, help="PPO clip upper range")
@@ -984,9 +987,6 @@ def slime_validate_args(args):
984987
args.ckpt_step = args.ref_ckpt_step
985988
args.start_rollout_id = 0
986989

987-
if args.critic_load is None:
988-
args.critic_load = args.load
989-
990990
if args.eval_interval is not None:
991991
assert args.eval_prompt_data is not None, "eval_prompt_data must be set when eval_interval is set"
992992
if len(args.eval_prompt_data) == 1:
@@ -1032,6 +1032,10 @@ def slime_validate_args(args):
10321032
args.critic_num_gpus_per_node = args.actor_num_gpus_per_node
10331033
if args.critic_num_nodes is None:
10341034
args.critic_num_nodes = args.actor_num_nodes
1035+
if args.critic_load is None:
1036+
args.critic_load = args.load
1037+
if args.critic_lr is None:
1038+
args.critic_lr = args.lr
10351039

10361040
if args.debug_rollout_only:
10371041
if args.colocate and args.rollout_num_gpus is None:

slime/utils/http_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def init_http_client(concurrency: int):
8383
if _http_client is None:
8484
_http_client = httpx.AsyncClient(
8585
limits=httpx.Limits(max_connections=concurrency),
86-
timeout=httpx.Timeout(None),
86+
timeout=httpx.Timeout(None, connect=5.0),
8787
)
8888

8989

@@ -113,7 +113,6 @@ async def post(url, payload, max_retries=60):
113113

114114

115115
async def get(url):
116-
# never timeout
117116
response = await _http_client.get(url)
118117
response.raise_for_status()
119118
output = response.json()

train.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def train(args):
7070

7171
if args.use_critic:
7272
critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref)
73-
74-
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
75-
76-
if args.use_critic:
73+
if rollout_id >= args.num_critic_only_steps:
74+
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
7775
ray.get(critic_train_handle)
76+
else:
77+
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
7878

7979
if args.save_interval is not None and (
8080
(rollout_id + 1) % args.save_interval == 0
@@ -85,9 +85,12 @@ def train(args):
8585
ray.get(rollout_manager.save.remote(rollout_id))
8686

8787
if args.offload:
88-
ray.get(actor_model.async_offload())
8988
if args.use_critic:
9089
ray.get(critic_model.async_offload())
90+
if rollout_id >= args.num_critic_only_steps:
91+
ray.get(actor_model.async_offload())
92+
else:
93+
ray.get(actor_model.async_offload())
9194

9295
ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS]))
9396

train_async.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,13 @@ def train(args):
2828
assert args.num_rollout > 0
2929

3030
# sync the initialization (model initialization, load checkpoint, etc.)
31-
# Note that we initialize it earlier as megatron ckpt loading may have really large peak memory usage.
31+
if args.use_critic:
32+
critic_init_handle = critic_model.async_init(args, role="critic", with_ref=False)
33+
3234
start_rollout_ids = ray.get(
3335
actor_model.async_init(args, role="actor", with_ref=args.kl_coef != 0 or args.use_kl_loss)
3436
)
37+
3538
assert len(set(start_rollout_ids)) == 1
3639
if args.start_rollout_id is None:
3740
args.start_rollout_id = start_rollout_ids[0]
@@ -42,6 +45,10 @@ def train(args):
4245
# initialize the connection for weight update during training
4346
ray.get(actor_model.async_init_weight_update_connections(rollout_manager))
4447

48+
if args.use_critic:
49+
ray.get(critic_init_handle)
50+
ray.get(actor_model.async_connect(critic_model))
51+
4552
# always update weight first so that sglang has the loaded weights from training.
4653
ray.get(actor_model.async_update_weights())
4754

@@ -56,7 +63,13 @@ def train(args):
5663
if rollout_id + 1 < args.num_rollout:
5764
rollout_data_next_future = rollout_manager.generate.remote(rollout_id + 1)
5865

59-
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
66+
if args.use_critic:
67+
critic_train_handle = critic_model.async_train(rollout_id, rollout_data_curr_ref)
68+
if rollout_id >= args.num_critic_only_steps:
69+
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
70+
ray.get(critic_train_handle)
71+
else:
72+
ray.get(actor_model.async_train(rollout_id, rollout_data_curr_ref))
6073

6174
if args.save_interval is not None and (
6275
(rollout_id + 1) % args.save_interval == 0

0 commit comments

Comments
 (0)