Skip to content

Commit b543a50

Browse files
authored
sync internal features (THUDM#1192)
1 parent f638b1c commit b543a50

File tree

8 files changed

+92
-6
lines changed

8 files changed

+92
-6
lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from slime.utils.timer import Timer, inverse_timer, timer
2626
from slime.utils.tracking_utils import init_tracking
2727
from slime.utils.types import RolloutBatch
28+
2829
from ...utils.profile_utils import TrainProfiler
2930
from ...utils.tensor_backper import TensorBackuper
3031
from .checkpoint import load_checkpoint
@@ -66,7 +67,6 @@ def init(
6667
if i == dist.get_rank():
6768
self.hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
6869
self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
69-
7070
dist.barrier(group=get_gloo_group())
7171

7272
self.train_parallel_config = {

slime/backends/megatron_utils/data.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,64 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc
376376
if args.log_passrate:
377377
log_passrate(rollout_id, args, rollout_data)
378378

379+
if args.log_correct_samples:
380+
if mpu.get_tensor_model_parallel_rank() == 0 and mpu.is_pipeline_last_stage():
381+
cp_size = mpu.get_context_parallel_world_size()
382+
log_dict = {}
383+
response_lengths = rollout_data["response_lengths"]
384+
loss_masks = rollout_data["loss_masks"]
385+
total_lengths = rollout_data["total_lengths"]
386+
387+
def quantile(total_value, n_quantiles, data) -> dict:
388+
import math
389+
390+
assert n_quantiles > 1, f"n_quantiles({n_quantiles}) must be greater than 1."
391+
392+
quantiles = [((i + 1) / n_quantiles) for i in range(n_quantiles)]
393+
cut_points = [total_value * q for q in quantiles]
394+
cut_points[-1] = total_value
395+
396+
count = [0] * n_quantiles
397+
for d in data:
398+
for i, point in enumerate(cut_points):
399+
if d <= point:
400+
count[i] += 1
401+
break
402+
403+
total = sum(count) + 1e-9
404+
percentile = [c / total for c in count]
405+
406+
percentile = {f"p{min(math.ceil(q*100),100)}": p for q, p in zip(quantiles, percentile, strict=True)}
407+
return percentile
408+
409+
raw_rewards = rollout_data["raw_reward"]
410+
# Additional metrics for correct cases are calculated separately below.
411+
correct_response_lengths = []
412+
correct_total_lengths = []
413+
correct_loss_masks = []
414+
correct_entropy = []
415+
for i, raw_reward in enumerate(raw_rewards):
416+
if raw_reward == 1:
417+
correct_response_lengths.append(response_lengths[i])
418+
correct_total_lengths.append(total_lengths[i])
419+
correct_loss_masks.append(loss_masks[i])
420+
correct_entropy.append(-rollout_data["log_probs"][i])
421+
num_correct_responses = len(correct_total_lengths)
422+
rollout_data["correct_response_lengths"] = correct_response_lengths
423+
correct_response_length_percentile = quantile(
424+
args.rollout_max_response_len, 4, rollout_data["correct_response_lengths"]
425+
)
426+
for p, val in correct_response_length_percentile.items():
427+
rollout_data[f"correct_length/{p}"] = [val] * num_correct_responses
428+
if len(correct_entropy) > 0:
429+
sum_of_sample_mean = get_sum_of_sample_mean(
430+
correct_total_lengths, correct_response_lengths, correct_loss_masks
431+
)
432+
correct_entropy = sum_of_sample_mean(torch.cat(correct_entropy, dim=0))
433+
rollout_data["correct_entropy"] = [correct_entropy.item()] * num_correct_responses
434+
else:
435+
rollout_data["correct_entropy"] = [0] * num_correct_responses
436+
379437

380438
def log_multi_turn_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatch) -> None:
381439
"""

slime/backends/megatron_utils/initialize.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import torch
66
from megatron.core import mpu, tensor_parallel
7+
from megatron.core.config import set_experimental_flag
78
from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator
89
from megatron.training.global_vars import _build_tokenizer, set_args
910

@@ -54,6 +55,10 @@ def _initialize_distributed(args, get_embedding_ranks=None, get_position_embeddi
5455

5556
def init(args):
5657
set_args(args)
58+
if args.enable_experimental:
59+
logger.info("Enable megatron experimental")
60+
set_experimental_flag(True)
61+
5762
# Pytorch distributed.
5863
_initialize_distributed(args)
5964

slime/rollout/sglang_rollout.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,10 @@ async def generate_and_rm(
214214
sampling_params: dict[str, Any],
215215
evaluation: bool = False,
216216
) -> Sample | list[Sample]:
217+
# mask previous off-policy generation for partial rollout
218+
if args.partial_rollout and args.mask_offpolicy_in_partial_rollout and sample.response_length > 0:
219+
sample.loss_mask = [0] * sample.response_length
220+
217221
# For samples with existing response, check if they're complete
218222
if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED:
219223
assert sample.response is not None

slime/utils/arguments.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,15 @@ def add_rollout_arguments(parser):
316316
"This is useful for long responses."
317317
),
318318
)
319+
parser.add_argument(
320+
"--mask-offpolicy-in-partial-rollout",
321+
action="store_true",
322+
default=False,
323+
help=(
324+
"Whether to mask previous generation in partial rollout. "
325+
"If set, only on-policy generated tokens will be used in training"
326+
),
327+
)
319328
parser.add_argument(
320329
"--custom-generate-function-path",
321330
type=str,
@@ -600,6 +609,12 @@ def add_eval_arguments(parser):
600609
"When provided, this overrides --eval-prompt-data."
601610
),
602611
)
612+
parser.add_argument(
613+
"--skip-eval-before-train",
614+
action="store_true",
615+
default=False,
616+
help="Whether to skip evaluation before training.",
617+
)
603618

604619
# The following keys are used to override the rollout version during eval.
605620
parser.add_argument("--eval-input-key", type=str, default=None, help="JSON dataset key")
@@ -922,6 +937,12 @@ def add_wandb_arguments(parser):
922937
"Specify the key in the reward dict using this argument.",
923938
),
924939
)
940+
parser.add_argument(
    "--log-correct-samples",
    action="store_true",
    default=False,
    # Help text fixed: it previously described --log-passrate (pass@n logging),
    # a copy-paste error. This flag enables per-rollout metrics computed over
    # the correct samples (raw reward == 1) only.
    help=(
        "Whether to log extra metrics for correct samples (raw reward == 1) in the rollout, "
        "e.g. response-length percentiles and entropy of the correct responses."
    ),
)
925946
parser.add_argument("--wandb-run-id", type=str, default=None)
926947
return parser
927948

slime_plugins/models/hf_attention.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from megatron.core import mpu, tensor_parallel
66
from megatron.core.inference.contexts import BaseInferenceContext
77
from megatron.core.packed_seq_params import PackedSeqParams
8-
from megatron.core.process_groups_config import ProcessGroupCollection
98
from megatron.core.transformer.module import MegatronModule
109
from transformers import AutoConfig
1110

@@ -23,7 +22,7 @@ def __init__(
2322
config,
2423
layer_number: int,
2524
cp_comm_type: str = "p2p",
26-
pg_collection: ProcessGroupCollection = None,
25+
pg_collection=None,
2726
):
2827
super().__init__(config=config)
2928
self.args = args

slime_plugins/models/qwen3_next.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
7-
from megatron.core.process_groups_config import ProcessGroupCollection
87
from megatron.core.transformer.spec_utils import ModuleSpec
98
from megatron.core.transformer.transformer_block import get_num_layers_to_build
109
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
@@ -170,7 +169,7 @@ def __init__(
170169
config,
171170
layer_number: int,
172171
cp_comm_type: str = "p2p",
173-
pg_collection: ProcessGroupCollection = None,
172+
pg_collection=None,
174173
):
175174
super().__init__(
176175
args,

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def onload_rollout():
6262
# train loop.
6363
# note that for async training, one can change the position of the sync operation(ray.get).
6464
for rollout_id in range(args.start_rollout_id, args.num_rollout):
65-
if args.eval_interval is not None and rollout_id == 0:
65+
if args.eval_interval is not None and rollout_id == 0 and not args.skip_eval_before_train:
6666
ray.get(rollout_manager.eval.remote(rollout_id))
6767

6868
rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))

0 commit comments

Comments
 (0)