@@ -1814,6 +1814,22 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
18141814 last_hidden_state = last_hidden_state [:, - logits_to_keep :, :] # (B, logits_to_keep, H)
18151815 return last_hidden_state
18161816
def _compute_rollout_is_ratio_for_liger(self, inputs, completion_mask):
    """Build the rollout importance-sampling weights for the liger loss path.

    Args:
        inputs: Mapping queried with ``.get`` for the keys
            ``'rollout_per_token_logps'`` and ``'old_per_token_logps'``.
        completion_mask: Forwarded unchanged to
            ``self._apply_rollout_importance_sampling``.

    Returns:
        ``None`` when rollout importance sampling is off (the mode is unset
        or it is explicitly disabled) or when either log-prob entry is
        missing from ``inputs``. Otherwise, whatever
        ``self._apply_rollout_importance_sampling`` returns for the
        log-ratio ``old_per_token_logps - rollout_per_token_logps``.
    """
    # Same short-circuit order as checking "mode is None" first: the disable
    # flag is only read when a mode is actually configured.
    is_enabled = (
        self.rollout_importance_sampling_mode is not None
        and not self.disable_rollout_importance_sampling
    )
    if not is_enabled:
        return None

    sampler_logps = inputs.get('rollout_per_token_logps')
    policy_logps = inputs.get('old_per_token_logps')
    # Identity checks only: the values may not support truthiness or "==".
    if sampler_logps is None or policy_logps is None:
        return None

    log_ratio = policy_logps - sampler_logps
    return self._apply_rollout_importance_sampling(log_ratio, completion_mask)
1832+
18171833 def compute_liger_loss (self , unwrapped_model , inputs ):
18181834 # Compute the per-token log probabilities for the model
18191835 assert not self .template .padding_free
@@ -1826,6 +1842,12 @@ def compute_liger_loss(self, unwrapped_model, inputs):
18261842 # get the last hidden state of the model
18271843 last_hidden_state = self ._get_last_hidden_state (unwrapped_model , inputs , logits_to_keep )
18281844 # compute loss and metrics using liger grpo loss
1845+
1846+ kwargs = {}
1847+ vllm_is_ratio = self ._compute_rollout_is_ratio_for_liger (inputs , completion_mask )
1848+ if vllm_is_ratio is not None :
1849+ kwargs ['vllm_is_ratio' ] = vllm_is_ratio
1850+
18291851 loss , metrics = self .liger_grpo_loss (
18301852 _input = last_hidden_state ,
18311853 lin_weight = unwrapped_model .lm_head .weight ,
@@ -1835,9 +1857,9 @@ def compute_liger_loss(self, unwrapped_model, inputs):
18351857 bias = unwrapped_model .lm_head .bias ,
18361858 old_per_token_logps = inputs .get ('old_per_token_logps' ),
18371859 ref_per_token_logps = inputs .get ('ref_per_token_logps' ),
1860+ ** kwargs ,
18381861 )
1839- # Extract metrics from the liger_grpo_loss output
1840- # KL divergence is the first metric when beta is non-zero
1862+
18411863 mean_kl = metrics [0 ] if self .beta != 0.0 else None
18421864 clip_ratio = metrics [- 1 ]
18431865
@@ -2097,9 +2119,6 @@ def _prepare_liger_loss(self):
20972119 self .use_liger_loss = self .args .use_liger_kernel
20982120 if self .use_liger_loss :
20992121 from liger_kernel .chunked_loss import LigerFusedLinearGRPOLoss
2100- kwargs = {}
2101- if 'importance_sampling_level' in inspect .signature (LigerFusedLinearGRPOLoss .__init__ ).parameters :
2102- kwargs ['importance_sampling_level' ] = self .importance_sampling_level
21032122 self .liger_grpo_loss = LigerFusedLinearGRPOLoss (
21042123 beta = self .beta ,
21052124 epsilon_low = self .epsilon_low ,
@@ -2108,7 +2127,9 @@ def _prepare_liger_loss(self):
21082127 use_ref_model = self .beta != 0.0 ,
21092128 loss_type = self .loss_type ,
21102129 max_completion_length = self .max_completion_length ,
2111- ** kwargs ,
2130+ importance_sampling_level = self .importance_sampling_level ,
2131+ sapo_temperature_pos = self .tau_pos ,
2132+ sapo_temperature_neg = self .tau_neg ,
21122133 )
21132134 self ._forward_redirection = _ForwardRedirection ()
21142135
0 commit comments