Commit e685491 ("fix")
1 parent: 2f1e21d

4 files changed: +105, -108 lines

examples/megatron/rlhf/gkd/opsd.sh (+6, -1)

@@ -15,14 +15,19 @@ megatron rlhf \
     --model Qwen/Qwen3-4B \
     --external_plugins examples/train/rlhf/opsd/opsd_plugin.py \
     --dataset 'open-r1/OpenThoughts-114k-math' \
+    --use_vllm true \
+    --vllm_mode colocate \
+    --vllm_gpu_memory_utilization 0.6 \
+    --vllm_max_model_len 10240 \
+    --sleep_level 1 \
     --lmbda 1.0 \
     --beta 0.5 \
     --temperature 1.2 \
     --sft_alpha 0 \
     --torch_dtype bfloat16 \
     --micro_batch_size 1 \
     --global_batch_size 32 \
-    --max_steps 1000 \
+    --train_iters 1000 \
     --lr 2e-5 \
     --save_steps 100 \
     --save_total_limit 10 \

swift/megatron/trainers/grpo_trainer.py (+48, -104)

@@ -279,6 +279,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):

        truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], dtype=torch.bool, device=self.device)

+        rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
+
        if template.padding_free:
            # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
            # Calculate seq_lengths from cu_seq_lens or position_ids
@@ -290,7 +292,7 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
            max_seq_len = seq_lengths.max().item()

            # completion_mask in rmpad format [1, total_tokens]
-            completion_mask_rmpad = (labels != -100).float()
+            completion_mask_rmpad = (rolled_labels != -100).float()
            completion_mask, _ = pad_logps_back_to_batch(
                logps_rmpad=completion_mask_rmpad,
                logits_to_keep=max_seq_len,
@@ -312,8 +314,8 @@ def _get_encoded_batch(self, encoded_list, rollout_batch, template):
            seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
            max_seq_len = labels.shape[-1]

-            # completion_mask is already [batch_size, seq_len] in non-padding_free mode
-            completion_mask = (labels != -100)
+            # completion_mask based on rolled labels for alignment with per_token_logps
+            completion_mask = (rolled_labels != -100)

        encoded_batch.update({
            'completion_mask': completion_mask,  # [batch_size, max_seq_len]
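
Note on the rolled labels: the diff's own comment says the completion mask is built from rolled labels "for alignment with per_token_logps". The sketch below (standalone, with made-up label values, not the trainer's real tensors) just shows what torch.roll(shifts=-1) does to the -100 mask, i.e. how the mask moves one position earlier:

import torch

# Hypothetical labels: two prompt positions masked with -100, then three
# completion tokens.
labels = torch.tensor([[-100, -100, 11, 12, 13]])

# torch.roll(shifts=-1) moves every entry one position left; position t now
# holds the label for position t + 1 (the first element wraps to the end).
rolled = torch.roll(labels, shifts=-1, dims=-1)
print(rolled)          # tensor([[-100,   11,   12,   13, -100]])

print(rolled != -100)  # tensor([[False,  True,  True,  True, False]]) -- new mask
print(labels != -100)  # tensor([[False, False,  True,  True,  True]]) -- old mask, one step later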
@@ -934,34 +936,31 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:

        inputs = self._prepare_model_inputs(batch)
        if self.beta != 0.0:
-            with torch.no_grad(), self.null_ref_context() as ref_models:
+            with self.null_ref_context() as ref_models:
                assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
                ref_model = ref_models[0]
-                ref_per_token_logps_raw = self.model_forward(
-                    ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+                ref_per_token_logps_packed = self.compute_per_token_logps(
+                    ref_model, iter([deepcopy(inputs)]), temperature=self.temperature)
            if self.template.padding_free:
-                # In padding_free mode, logps are in rmpad format [1, total_tokens]
-                # Pad to batch format [batch_size, max_seq_len]
                ref_per_token_logps, _ = pad_logps_back_to_batch(
-                    logps_rmpad=ref_per_token_logps_raw,
+                    logps_rmpad=ref_per_token_logps_packed,
                    logits_to_keep=max_seq_len,
                    batch_size=batch_size,
                    seq_lengths=seq_lengths)
            else:
-                # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
-                ref_per_token_logps = ref_per_token_logps_raw
+                ref_per_token_logps = ref_per_token_logps_packed
            batch['ref_per_token_logps'] = ref_per_token_logps

-        old_per_token_logps_raw = self.model_forward(
-            self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
+        old_per_token_logps_packed = self.compute_per_token_logps(
+            self.unwrapped_models[0], iter([deepcopy(inputs)]), temperature=self.temperature)
        if self.template.padding_free:
            old_per_token_logps, _ = pad_logps_back_to_batch(
-                logps_rmpad=old_per_token_logps_raw,
+                logps_rmpad=old_per_token_logps_packed,
                logits_to_keep=max_seq_len,
                batch_size=batch_size,
                seq_lengths=seq_lengths)
        else:
-            old_per_token_logps = old_per_token_logps_raw
+            old_per_token_logps = old_per_token_logps_packed
        batch['old_per_token_logps'] = old_per_token_logps

        return batch
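
The point of switching to compute_per_token_logps here: the ref/old logps feed importance ratios, so they must be computed at the same temperature used for sampling; this is the motivation the new method's docstring gives (see rlhf_mixin.py below). A self-contained sketch of temperature-scaled per-token logps, with a hypothetical helper name and made-up shapes, not the repo's implementation:

import torch
import torch.nn.functional as F

def per_token_logps(logits, labels, temperature=1.0):
    # Hypothetical standalone helper; the trainer delegates this step to
    # compute_logps_and_entropy_from_logits after scaling logits in place.
    if temperature != 1.0:
        logits = logits / temperature  # scale before log_softmax
    logps = F.log_softmax(logits.float(), dim=-1)
    gathered = logps.gather(-1, labels.clamp_min(0).unsqueeze(-1)).squeeze(-1)
    return gathered.masked_fill(labels == -100, 0.0)  # zero out masked positions

# Toy check: a higher temperature flattens the distribution, changing every
# token's logp and hence any importance ratio derived from it.
logits = torch.randn(1, 5, 32)  # [batch, seq, vocab], made-up sizes
labels = torch.tensor([[-100, 3, 7, 1, -100]])
print(per_token_logps(logits, labels, temperature=1.0))
print(per_token_logps(logits, labels, temperature=1.2))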
@@ -1052,69 +1051,46 @@ def forward_step(self, data_iterator, model):

        # Check if this is the PP last stage (only last stage has labels and computes loss)
        is_pp_last_stage = mpu.is_pipeline_last_stage()
-
-        if self.compute_entropy:
-            # Forward without labels to get logits, then compute logps and entropy
-            inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
-            output_tensor = model(**inputs_for_logits)
-
-            # Compute per_token_logps and per_token_entropy from logits on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                # output_tensor is logits [batch/1, seq, partition_vocab_size]
-                per_token_logps_raw, per_token_entropy_raw = compute_logps_and_entropy_from_logits(
-                    output_tensor, labels, compute_entropy=True)
-
-                # In CP mode, all_gather and reconstruct full sequence
-                if args.context_parallel_size > 1:
-                    num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
-                    per_token_logps_raw = self._postprocess_packed_tensor_cp(per_token_logps_raw, packed_seq_params,
-                                                                             num_samples)
-                    per_token_entropy_raw = self._postprocess_packed_tensor_cp(per_token_entropy_raw,
-                                                                               packed_seq_params, num_samples)
-
-                if args.padding_free:
-                    # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
+        inputs_for_logits = {k: v for k, v in inputs.items() if k != 'labels'}
+        logits_packed = model(**inputs_for_logits)
+        output_tensor = None
+        if is_pp_last_stage and logits_packed is not None:
+            if self.temperature != 1.0:
+                logits_packed.div_(self.temperature)
+            per_token_logps_packed, per_token_entropy_packed = compute_logps_and_entropy_from_logits(
+                logits_packed, labels, compute_entropy=self.compute_entropy)
+
+            # In CP mode, all_gather and reconstruct full sequence
+            if args.context_parallel_size > 1:
+                num_samples = packed_seq_params.num_samples if args.padding_free else micro_batch_size
+                per_token_logps_packed = self._postprocess_packed_tensor_cp(per_token_logps_packed,
+                                                                            packed_seq_params, num_samples)
+                if per_token_entropy_packed is not None:
+                    per_token_entropy_packed = self._postprocess_packed_tensor_cp(per_token_entropy_packed,
+                                                                                  packed_seq_params, num_samples)
+
+            if args.padding_free:
+                # Pad from rmpad [1, total_tokens] to batch format [batch_size, max_seq_len]
+                per_token_logps, _ = pad_logps_back_to_batch(
+                    logps_rmpad=per_token_logps_packed,
+                    logits_to_keep=max_seq_len,
+                    batch_size=micro_batch_size,
+                    seq_lengths=seq_lengths)
+                if per_token_entropy_packed is not None:
                    per_token_entropy, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_entropy_raw,
+                        logps_rmpad=per_token_entropy_packed,
                        logits_to_keep=max_seq_len,
                        batch_size=micro_batch_size,
                        seq_lengths=seq_lengths,
                        pad_value=float('nan'))
                else:
-                    per_token_logps = per_token_logps_raw
-                    per_token_entropy = per_token_entropy_raw
-
-                data['per_token_logps'] = per_token_logps
-                data['per_token_entropy'] = per_token_entropy
-        else:
-            # Standard forward with labels, returns per-token loss (more efficient)
-            output_tensor = model(**inputs)
-
-            # Convert output_tensor (per-token loss) to per_token_logps on PP last stage
-            if is_pp_last_stage and output_tensor is not None:
-                per_token_logps_raw = self.get_logps(
-                    output_tensor,
-                    labels,
-                    packed_seq_params,
-                    packed_seq_params.num_samples if args.padding_free else micro_batch_size,
-                    per_token=True)
-
-                if args.padding_free:
-                    per_token_logps, _ = pad_logps_back_to_batch(
-                        logps_rmpad=per_token_logps_raw,
-                        logits_to_keep=max_seq_len,
-                        batch_size=micro_batch_size,
-                        seq_lengths=seq_lengths)
-                else:
-                    per_token_logps = per_token_logps_raw
+                    per_token_entropy = None
+            else:
+                per_token_logps = per_token_logps_packed
+                per_token_entropy = per_token_entropy_packed

-            data['per_token_logps'] = per_token_logps
-            data['per_token_entropy'] = None
+            output_tensor = per_token_logps
+            data['per_token_entropy'] = per_token_entropy

        return output_tensor, partial(self.loss_func, data=data)
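
Both branches above lean on pad_logps_back_to_batch, which is called but never defined in this diff. Below is a hypothetical sketch of what a function with this call signature plausibly does (scatter rmpad-format values [1, total_tokens] back into a padded [batch_size, max_seq_len] tensor); the real second return value is not visible in the diff, so the sketch returns seq_lengths as a stand-in:

import torch

def pad_logps_back_to_batch(logps_rmpad, logits_to_keep, batch_size, seq_lengths, pad_value=0.0):
    # Hypothetical sketch of the rmpad -> batch conversion used above.
    # logps_rmpad: [1, total_tokens], all sequences concatenated.
    # seq_lengths: [batch_size], true length of each packed sequence.
    out = torch.full((batch_size, logits_to_keep), pad_value,
                     dtype=logps_rmpad.dtype, device=logps_rmpad.device)
    offset = 0
    for i, length in enumerate(seq_lengths.tolist()):
        out[i, :length] = logps_rmpad[0, offset:offset + length]
        offset += length
    return out, seq_lengths

# Usage: two packed sequences of lengths 3 and 2, padded to max_seq_len 3.
flat = torch.arange(5, dtype=torch.float32).unsqueeze(0)  # [1, 5]
padded, _ = pad_logps_back_to_batch(flat, 3, 2, torch.tensor([3, 2]),
                                    pad_value=float('nan'))
print(padded)  # [[0., 1., 2.], [3., 4., nan]]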

@@ -1129,7 +1105,7 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

        # Get pre-computed per_token_logps and per_token_entropy from forward_step
        # These are already in batch format [batch_size, max_seq_len]
-        per_token_logps = data.get('per_token_logps')
+        per_token_logps = output_tensor
        per_token_entropy = data.get('per_token_entropy')

        # Get pre-padded ref/old/rollout logps from data
@@ -1409,38 +1385,6 @@ def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):

        return loss, reporting_metric

-    def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
-        """Forward pass through model to compute logps.
-
-        Args:
-            model: The model to forward
-            data_iterator: Iterator providing batch data
-            no_grad: Whether to use torch.no_grad() context
-            per_token: Whether to return per-token logps
-
-        Returns:
-            data dict containing 'logps'
-        """
-        # used to calculate model forward (logps) in GRPO
-        data = self.get_batch(data_iterator)
-        data.pop('loss_scale', None)
-        input_ids = data.get('input_ids')
-        labels = data.get('labels')
-        context = torch.no_grad() if no_grad else nullcontext()
-
-        with context:
-            output_tensor = forward_step_helper(self.args, model, data)
-
-        # packed_seq_params only exists in padding_free mode
-        packed_seq_params = data.get('packed_seq_params')
-        if packed_seq_params is not None:
-            num_samples = packed_seq_params.num_samples
-        else:
-            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
-        data['logps'] = None if labels is None else self.get_logps(
-            output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
-        return data
-
    def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:
        """Convert raw input data into RolloutInferRequest objects"""

swift/megatron/trainers/rlhf_mixin.py (+46, -2)

@@ -1,15 +1,16 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import torch
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from megatron.core import mpu
 from torch.distributed.nn import all_reduce
 from transformers.utils import ContextManagers

 from swift.megatron.model import get_mcore_model
-from swift.megatron.utils import load_mcore_checkpoint
+from swift.megatron.utils import forward_step_helper, load_mcore_checkpoint
 from swift.rlhf_trainers.utils import identity_data_collator
 from swift.utils import get_logger
 from .base import BaseMegatronTrainer
+from .vocab_parallel_utils import compute_logps_and_entropy_from_logits

 logger = get_logger()

@@ -91,6 +92,49 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples, per_t
            all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group())
        return all_logps

+    def compute_per_token_logps(self, model, data_iterator, no_grad=True, temperature=1.0):
+        """Forward pass to get logits, then compute temperature-scaled per-token logps.
+
+        Unlike get_logps (which recovers logps from the cross-entropy loss), this method
+        obtains raw logits from the model and computes logps with temperature scaling,
+        which is required for importance sampling in GRPO and potentially other algorithms.
+
+        Args:
+            model: The model to forward
+            data_iterator: Iterator providing batch data
+            no_grad: Whether to disable gradient computation (default: True)
+            temperature: Temperature for scaling logits before log_softmax
+
+        Returns:
+            per_token_logps tensor, or None if on a non-last PP stage
+        """
+        data = self.get_batch(data_iterator)
+        data.pop('loss_scale', None)
+        labels = data.get('labels')
+
+        data_for_forward = {k: v for k, v in data.items() if k != 'labels'}
+        context = torch.no_grad() if no_grad else nullcontext()
+        with context:
+            output_tensor = forward_step_helper(self.args, model, data_for_forward)
+
+        if labels is None or output_tensor is None:
+            return None
+
+        if temperature != 1.0:
+            output_tensor.div_(temperature)
+        per_token_logps, _ = compute_logps_and_entropy_from_logits(output_tensor, labels)
+
+        packed_seq_params = data.get('packed_seq_params')
+        if packed_seq_params is not None:
+            num_samples = packed_seq_params.num_samples
+        else:
+            input_ids = data.get('input_ids')
+            num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
+
+        if self.args.context_parallel_size > 1:
+            per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples)
+        return per_token_logps
+
    def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
        """
        Generic method: In CP mode, all_gather and reconstruct full tensor sequences.
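
A sketch of how the new method is invoked, mirroring the _maybe_compute_logps call sites in grpo_trainer.py above (trainer, ref_model, and inputs are hypothetical stand-ins for the trainer's own state):

from copy import deepcopy

# Hypothetical call, mirroring the grpo_trainer.py call sites above:
ref_logps = trainer.compute_per_token_logps(
    ref_model,                       # model whose logps we want
    iter([deepcopy(inputs)]),        # single-batch iterator, as in the diff
    temperature=trainer.temperature)

Note the in-place output_tensor.div_(temperature) before the pure helper: scaling in place avoids allocating a second vocab-sized logits tensor, which is presumably why the docstring change in vocab_parallel_utils.py below pushes temperature scaling onto the caller.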

swift/megatron/trainers/vocab_parallel_utils.py (+5, -1)

@@ -207,8 +207,12 @@ def compute_logps_and_entropy_from_logits(
    Note: In Megatron, labels are already shifted (via torch.roll in get_batch_on_this_tp_rank),
    so logits and labels are already aligned. No additional shift is needed here.

+    Temperature scaling should be applied by the caller before invoking this function,
+    so that this function remains a pure computation without side effects on the input.
+
    Args:
-        logits: Logits tensor [batch, seq, partition_vocab_size] or [1, total_tokens, partition_vocab_size]
+        logits: Logits tensor [batch, seq, partition_vocab_size] or [1, total_tokens, partition_vocab_size].
+            Should be pre-scaled by temperature if needed.
        labels: Token labels [batch, seq] or [1, total_tokens], -100 for masked positions
        compute_entropy: Whether to compute entropy (default: False)
        entropy_chunk_size: Chunk size for entropy computation (default: 512)
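
The entropy_chunk_size argument indicates entropy is computed a chunk of tokens at a time to bound peak memory, so a full [tokens, vocab] probability tensor is never materialized at once. A minimal single-GPU sketch of that pattern (ignoring the vocab-parallel reduction the real helper has to perform):

import torch
import torch.nn.functional as F

def chunked_entropy(logits, chunk_size=512):
    # Per-token entropy H = -sum_v p_v * log p_v, computed chunk_size tokens
    # at a time so only a [chunk_size, vocab] softmax is live at once.
    flat = logits.reshape(-1, logits.shape[-1])
    outs = []
    for chunk in flat.split(chunk_size, dim=0):
        logp = F.log_softmax(chunk.float(), dim=-1)
        outs.append(-(logp.exp() * logp).sum(dim=-1))
    return torch.cat(outs).reshape(logits.shape[:-1])

print(chunked_entropy(torch.randn(1, 1000, 128), chunk_size=256).shape)  # torch.Size([1, 1000])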
