Skip to content

Commit c6ba986

Browse files
committed
bump liger
1 parent de6e595 commit c6ba986

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

swift/arguments/rlhf_args.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -458,6 +458,8 @@ def _check_grpo(self):
458458
'Please set NPROC_PER_NODE equal to num_processes.')
459459
if self.use_liger_kernel:
460460
liger_kernel_version = version.parse(importlib.metadata.version('liger-kernel'))
461+
if liger_kernel_version < version.parse('0.7.0'):
462+
raise ValueError('Please update liger-kernel to 0.7.0 or later: pip install -U liger-kernel')
461463
if self.delta is not None:
462464
raise ValueError('Liger loss does not support two-sided GRPO loss yet.')
463465
if self.sequence_parallel_size > 1:
@@ -468,18 +470,15 @@ def _check_grpo(self):
468470
raise ValueError('Liger loss does not support entropy mask yet.')
469471
if self.log_entropy:
470472
raise ValueError('Liger loss does not support log entropy yet.')
473+
if self.off_policy_sequence_mask_delta is not None:
474+
raise ValueError('Liger loss does not support off-policy sequence masking yet.')
471475
if self.importance_sampling_level != 'token':
472-
if liger_kernel_version < version.parse('0.6.3'):
473-
raise ValueError('Please update liger-kernel to 0.6.3 or later')
474476
if self.importance_sampling_level == 'sequence_token':
475477
self.importance_sampling_level = 'sequence'
476478
logger.info('Remapping `importance_sampling_level` from `sequence_token` to `sequence` for '
477479
'liger-kernel compatibility. The two methods are computationally equivalent.')
478480
if self.advantage_estimator != 'grpo':
479481
raise ValueError('Liger loss currently only support grpo advantage estimator')
480-
from trl.import_utils import is_liger_kernel_available
481-
assert is_liger_kernel_available(), (
482-
'Please install/update liger-kernel by running: pip install -U liger-kernel')
483482

484483
if self.async_generate and self.multi_turn_scheduler is not None:
485484
raise NotImplementedError('Currently, async_generate is not supported with multi-turn functionality.')

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,6 +1814,22 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
18141814
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
18151815
return last_hidden_state
18161816

1817+
def _compute_rollout_is_ratio_for_liger(self, inputs, completion_mask):
1818+
"""Compute rollout importance sampling ratio for liger loss path.
1819+
1820+
Returns the IS ratio tensor or None if not applicable.
1821+
"""
1822+
if self.rollout_importance_sampling_mode is None or self.disable_rollout_importance_sampling:
1823+
return None
1824+
rollout_per_token_logps = inputs.get('rollout_per_token_logps')
1825+
old_per_token_logps = inputs.get('old_per_token_logps')
1826+
if rollout_per_token_logps is None or old_per_token_logps is None:
1827+
return None
1828+
1829+
rollout_log_ratio = old_per_token_logps - rollout_per_token_logps
1830+
rollout_is_weights = self._apply_rollout_importance_sampling(rollout_log_ratio, completion_mask)
1831+
return rollout_is_weights
1832+
18171833
def compute_liger_loss(self, unwrapped_model, inputs):
18181834
# Compute the per-token log probabilities for the model
18191835
assert not self.template.padding_free
@@ -1826,6 +1842,12 @@ def compute_liger_loss(self, unwrapped_model, inputs):
18261842
# get the last hidden state of the model
18271843
last_hidden_state = self._get_last_hidden_state(unwrapped_model, inputs, logits_to_keep)
18281844
# compute loss and metrics using liger grpo loss
1845+
1846+
kwargs = {}
1847+
vllm_is_ratio = self._compute_rollout_is_ratio_for_liger(inputs, completion_mask)
1848+
if vllm_is_ratio is not None:
1849+
kwargs['vllm_is_ratio'] = vllm_is_ratio
1850+
18291851
loss, metrics = self.liger_grpo_loss(
18301852
_input=last_hidden_state,
18311853
lin_weight=unwrapped_model.lm_head.weight,
@@ -1835,9 +1857,9 @@ def compute_liger_loss(self, unwrapped_model, inputs):
18351857
bias=unwrapped_model.lm_head.bias,
18361858
old_per_token_logps=inputs.get('old_per_token_logps'),
18371859
ref_per_token_logps=inputs.get('ref_per_token_logps'),
1860+
**kwargs,
18381861
)
1839-
# Extract metrics from the liger_grpo_loss output
1840-
# KL divergence is the first metric when beta is non-zero
1862+
18411863
mean_kl = metrics[0] if self.beta != 0.0 else None
18421864
clip_ratio = metrics[-1]
18431865

@@ -2097,9 +2119,6 @@ def _prepare_liger_loss(self):
20972119
self.use_liger_loss = self.args.use_liger_kernel
20982120
if self.use_liger_loss:
20992121
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
2100-
kwargs = {}
2101-
if 'importance_sampling_level' in inspect.signature(LigerFusedLinearGRPOLoss.__init__).parameters:
2102-
kwargs['importance_sampling_level'] = self.importance_sampling_level
21032122
self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
21042123
beta=self.beta,
21052124
epsilon_low=self.epsilon_low,
@@ -2108,7 +2127,9 @@ def _prepare_liger_loss(self):
21082127
use_ref_model=self.beta != 0.0,
21092128
loss_type=self.loss_type,
21102129
max_completion_length=self.max_completion_length,
2111-
**kwargs,
2130+
importance_sampling_level=self.importance_sampling_level,
2131+
sapo_temperature_pos=self.tau_pos,
2132+
sapo_temperature_neg=self.tau_neg,
21122133
)
21132134
self._forward_redirection = _ForwardRedirection()
21142135

0 commit comments

Comments
 (0)