Skip to content

Commit d4951d5

Browse files
committed
clean
1 parent 452a800 commit d4951d5

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,23 +1157,14 @@ def _compute_loss_and_metrics(self, model, inputs):
11571157
rollout_per_token_logps,
11581158
completion_mask)
11591159

1160-
# Apply importance sampling correction if mode is enabled
1161-
if self.rollout_importance_sampling_mode is not None:
1162-
# Compute the log ratio between policy model and rollout model
1163-
# log π_θ(y|x) - log π_rollout(y|x)
1164-
rollout_log_ratio = old_per_token_logps - rollout_per_token_logps
1165-
1166-
# Apply importance sampling correction based on mode
1167-
rollout_is_weights = self._apply_rollout_importance_sampling(rollout_log_ratio, completion_mask)
1168-
1169-
# Compute additional IS-specific metrics (ESS, clipped_frac, is_weight_mean)
1160+
rollout_log_ratio, rollout_is_weights = self._get_rollout_is_correction(old_per_token_logps,
1161+
rollout_per_token_logps,
1162+
completion_mask)
1163+
if rollout_log_ratio is not None:
11701164
is_metrics = self._compute_is_correction_metrics(rollout_log_ratio, rollout_is_weights, completion_mask)
11711165
rollout_correction_metrics.update(is_metrics)
11721166

1173-
# Store IS weights for loss computation
1174-
inputs['rollout_is_weights'] = rollout_is_weights
1175-
else:
1176-
inputs['rollout_is_weights'] = None
1167+
inputs['rollout_is_weights'] = rollout_is_weights
11771168
else:
11781169
inputs['rollout_is_weights'] = None
11791170

@@ -1814,44 +1805,46 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
18141805
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
18151806
return last_hidden_state
18161807

1817-
def _compute_rollout_is_ratio_for_liger(self, inputs, completion_mask):
1818-
"""Compute rollout importance sampling ratio for liger loss path.
1808+
def _get_rollout_is_correction(self, old_per_token_logps, rollout_per_token_logps, completion_mask):
1809+
"""Compute rollout importance sampling log-ratio and IS weights.
18191810
1820-
Returns the IS ratio tensor or None if not applicable.
1811+
Returns:
1812+
(rollout_log_ratio, rollout_is_weights) if rollout IS correction is applicable,
1813+
(None, None) otherwise.
18211814
"""
18221815
if self.rollout_importance_sampling_mode is None or self.disable_rollout_importance_sampling:
1823-
return None
1824-
rollout_per_token_logps = inputs.get('rollout_per_token_logps')
1825-
old_per_token_logps = inputs.get('old_per_token_logps')
1826-
if rollout_per_token_logps is None or old_per_token_logps is None:
1827-
return None
1816+
return None, None
18281817

18291818
rollout_log_ratio = old_per_token_logps - rollout_per_token_logps
18301819
rollout_is_weights = self._apply_rollout_importance_sampling(rollout_log_ratio, completion_mask)
1831-
return rollout_is_weights
1820+
return rollout_log_ratio, rollout_is_weights
18321821

18331822
def compute_liger_loss(self, unwrapped_model, inputs):
1834-
# Compute the per-token log probabilities for the model
18351823
assert not self.template.padding_free
18361824
assert self.advantage_estimator == 'grpo'
18371825
input_ids = inputs['input_ids']
18381826
logits_to_keep = inputs['logits_to_keep']
18391827
completion_ids = input_ids[:, -logits_to_keep:]
18401828
completion_mask = inputs['completion_mask']
18411829

1842-
# get the last hidden state of the model
18431830
last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep)
1844-
# compute loss and metrics using liger grpo loss
18451831

1846-
vllm_is_ratio = self._compute_rollout_is_ratio_for_liger(inputs, completion_mask)
1832+
old_per_token_logps = inputs.get('old_per_token_logps')
1833+
local_has = inputs.get('rollout_per_token_logps') is not None
1834+
vllm_is_ratio = None
1835+
if all(gather_object([local_has])):
1836+
rollout_per_token_logps = inputs['rollout_per_token_logps']
1837+
_, vllm_is_ratio = self._get_rollout_is_correction(old_per_token_logps, rollout_per_token_logps,
1838+
completion_mask)
1839+
18471840
loss, metrics = self.liger_grpo_loss(
18481841
_input=last_hidden_state,
18491842
lin_weight=unwrapped_model.lm_head.weight,
18501843
selected_token_ids=completion_ids,
18511844
attention_mask=completion_mask,
18521845
advantages=inputs['advantages'],
18531846
bias=unwrapped_model.lm_head.bias,
1854-
old_per_token_logps=inputs.get('old_per_token_logps'),
1847+
old_per_token_logps=old_per_token_logps,
18551848
ref_per_token_logps=inputs.get('ref_per_token_logps'),
18561849
vllm_is_ratio=vllm_is_ratio,
18571850
)

0 commit comments

Comments
 (0)