Labels: ⚡ PEFT · 🏋 GRPO · 🐛 bug
Description
I noticed a compatibility issue in GRPOTrainer when use_liger_loss=True is used together with a PEFT (LoRA) model whose lm_head is targeted for training.
In the compute_liger_loss method, the code passes the unwrapped model's lm_head.weight directly to LigerFusedLinearGRPOLoss:
```python
def compute_liger_loss(self, unwrapped_model, inputs):
    # ...
    last_hidden_state = self._get_last_hidden_state(...)
    # Issue: unwrapped_model.lm_head.weight is the frozen base weight when LoRA is active
    loss, metrics = self.liger_grpo_loss(
        _input=last_hidden_state,
        lin_weight=unwrapped_model.lm_head.weight,
        selected_token_ids=completion_ids,
        # ...
    )
    # ...
```
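For reference, a setup along the following lines should reach this code path. This is only a sketch: the model name, dataset, and reward function are placeholders, liger-kernel must be installed, and generation/batch-size settings are left at their defaults for brevity.

```python
from datasets import Dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

def dummy_reward(completions, **kwargs):
    # Placeholder reward so the trainer can run; any callable reward function works here.
    return [float(len(c)) for c in completions]

train_dataset = Dataset.from_dict({"prompt": ["Hello", "World"] * 8})

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj", "lm_head"],  # "lm_head" is the problematic entry
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder: any small causal LM
    reward_funcs=dummy_reward,
    args=GRPOConfig(output_dir="grpo-liger-lora", use_liger_loss=True),
    train_dataset=train_dataset,
    peft_config=peft_config,
)
trainer.train()
```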
If the user configures LoRA with target_modules that include "lm_head" (as in the sketch above), unwrapped_model.lm_head becomes a LoraLayer rather than a plain nn.Linear.
In this case:
- unwrapped_model.lm_head.weight refers to the frozen base model weights.
- The trainable parameters are in the separate LoRA adapters (lora_A, lora_B).
- Since LigerFusedLinearGRPOLoss computes the logits and loss from the provided lin_weight alone, the calculation effectively ignores the LoRA adapters; the sketch after this list shows the layer layout.
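A quick way to observe this structure, as a standalone sketch independent of TRL (the model name is again a placeholder):

```python
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", target_modules=["lm_head"]))

lm_head = model.base_model.model.lm_head
print(isinstance(lm_head, LoraLayer))                   # True: lm_head is now a LoRA-wrapped layer
print(lm_head.weight is lm_head.base_layer.weight)      # True: .weight resolves to the frozen base weight
print(lm_head.weight.requires_grad)                     # False: the base weight is frozen
print(lm_head.lora_A["default"].weight.requires_grad)   # True: the trainable parameters live in the adapters
```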
The LoRA adapters attached to lm_head therefore do not contribute to the loss and receive no gradients at all. The model behaves as if lm_head were frozen, even though the user intended to train it.
If lm_head is adapted via PEFT and use_liger_loss=True, GRPOTrainer should do one of the following (rough sketches of both options follow this list):
- Raise a warning or error preventing the usage of use_liger_loss=True when lm_head is in target_modules.
- Or merge the LoRA weights into the lm_head weight before passing it to the Liger kernel (though this might negate the memory benefits of using Liger).
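To make the suggestion concrete, here is a rough sketch of what each option could look like. This is not TRL's actual code; the helper names are made up, and the merge path assumes PEFT's LoraLayer API (base_layer, active_adapters, get_delta_weight).

```python
import torch
from peft.tuners.lora import LoraLayer

def check_liger_lora_lm_head(lm_head, use_liger_loss: bool) -> None:
    # Option 1: fail fast when the Liger path would silently ignore the lm_head adapters.
    if use_liger_loss and isinstance(lm_head, LoraLayer):
        raise ValueError(
            "use_liger_loss=True passes lm_head.weight directly to LigerFusedLinearGRPOLoss, "
            "so LoRA adapters on lm_head would be ignored. Remove 'lm_head' from target_modules "
            "or set use_liger_loss=False."
        )

def effective_lm_head_weight(lm_head) -> torch.Tensor:
    # Option 2: materialize base weight + LoRA delta before calling the Liger kernel.
    # Note: this allocates a full (vocab_size, hidden_size) matrix, which may offset
    # part of Liger's memory savings.
    if isinstance(lm_head, LoraLayer):
        weight = lm_head.base_layer.weight
        for adapter in lm_head.active_adapters:
            weight = weight + lm_head.get_delta_weight(adapter)
        return weight
    return lm_head.weight
```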