@@ -1157,23 +1157,14 @@ def _compute_loss_and_metrics(self, model, inputs):
11571157 rollout_per_token_logps ,
11581158 completion_mask )
11591159
1160- # Apply importance sampling correction if mode is enabled
1161- if self .rollout_importance_sampling_mode is not None :
1162- # Compute the log ratio between policy model and rollout model
1163- # log π_θ(y|x) - log π_rollout(y|x)
1164- rollout_log_ratio = old_per_token_logps - rollout_per_token_logps
1165-
1166- # Apply importance sampling correction based on mode
1167- rollout_is_weights = self ._apply_rollout_importance_sampling (rollout_log_ratio , completion_mask )
1168-
1169- # Compute additional IS-specific metrics (ESS, clipped_frac, is_weight_mean)
1160+ rollout_log_ratio , rollout_is_weights = self ._get_rollout_is_correction (old_per_token_logps ,
1161+ rollout_per_token_logps ,
1162+ completion_mask )
1163+ if rollout_log_ratio is not None :
11701164 is_metrics = self ._compute_is_correction_metrics (rollout_log_ratio , rollout_is_weights , completion_mask )
11711165 rollout_correction_metrics .update (is_metrics )
11721166
1173- # Store IS weights for loss computation
1174- inputs ['rollout_is_weights' ] = rollout_is_weights
1175- else :
1176- inputs ['rollout_is_weights' ] = None
1167+ inputs ['rollout_is_weights' ] = rollout_is_weights
11771168 else :
11781169 inputs ['rollout_is_weights' ] = None
11791170
@@ -1814,44 +1805,46 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
18141805 last_hidden_state = last_hidden_state [:, - logits_to_keep :, :] # (B, logits_to_keep, H)
18151806 return last_hidden_state
18161807
1817- def _compute_rollout_is_ratio_for_liger (self , inputs , completion_mask ):
1818- """Compute rollout importance sampling ratio for liger loss path .
1808+ def _get_rollout_is_correction (self , old_per_token_logps , rollout_per_token_logps , completion_mask ):
1809+ """Compute rollout importance sampling log- ratio and IS weights .
18191810
1820- Returns the IS ratio tensor or None if not applicable.
1811+ Returns:
1812+ (rollout_log_ratio, rollout_is_weights) if rollout IS correction is applicable,
1813+ (None, None) otherwise.
18211814 """
18221815 if self .rollout_importance_sampling_mode is None or self .disable_rollout_importance_sampling :
1823- return None
1824- rollout_per_token_logps = inputs .get ('rollout_per_token_logps' )
1825- old_per_token_logps = inputs .get ('old_per_token_logps' )
1826- if rollout_per_token_logps is None or old_per_token_logps is None :
1827- return None
1816+ return None , None
18281817
18291818 rollout_log_ratio = old_per_token_logps - rollout_per_token_logps
18301819 rollout_is_weights = self ._apply_rollout_importance_sampling (rollout_log_ratio , completion_mask )
1831- return rollout_is_weights
1820+ return rollout_log_ratio , rollout_is_weights
18321821
18331822 def compute_liger_loss (self , unwrapped_model , inputs ):
1834- # Compute the per-token log probabilities for the model
18351823 assert not self .template .padding_free
18361824 assert self .advantage_estimator == 'grpo'
18371825 input_ids = inputs ['input_ids' ]
18381826 logits_to_keep = inputs ['logits_to_keep' ]
18391827 completion_ids = input_ids [:, - logits_to_keep :]
18401828 completion_mask = inputs ['completion_mask' ]
18411829
1842- # get the last hidden state of the model
18431830 last_hidden_state = self ._get_last_hidden_state (unwrapped_model , inputs , logits_to_keep )
1844- # compute loss and metrics using liger grpo loss
18451831
1846- vllm_is_ratio = self ._compute_rollout_is_ratio_for_liger (inputs , completion_mask )
1832+ old_per_token_logps = inputs .get ('old_per_token_logps' )
1833+ local_has = inputs .get ('rollout_per_token_logps' ) is not None
1834+ vllm_is_ratio = None
1835+ if all (gather_object ([local_has ])):
1836+ rollout_per_token_logps = inputs ['rollout_per_token_logps' ]
1837+ _ , vllm_is_ratio = self ._get_rollout_is_correction (old_per_token_logps , rollout_per_token_logps ,
1838+ completion_mask )
1839+
18471840 loss , metrics = self .liger_grpo_loss (
18481841 _input = last_hidden_state ,
18491842 lin_weight = unwrapped_model .lm_head .weight ,
18501843 selected_token_ids = completion_ids ,
18511844 attention_mask = completion_mask ,
18521845 advantages = inputs ['advantages' ],
18531846 bias = unwrapped_model .lm_head .bias ,
1854- old_per_token_logps = inputs . get ( ' old_per_token_logps' ) ,
1847+ old_per_token_logps = old_per_token_logps ,
18551848 ref_per_token_logps = inputs .get ('ref_per_token_logps' ),
18561849 vllm_is_ratio = vllm_is_ratio ,
18571850 )
0 commit comments