Skip to content

Commit 555fb39

Browse files
hjh0119Jintao-Huang
authored andcommitted
[bugfix] fix mllm grpo liger loss (#9406)
1 parent 1efd051 commit 555fb39

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

swift/rlhf_trainers/grpo_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
10011001
inputs = inputs[0]
10021002
if self.use_liger_loss:
10031003
unwrapped_model = self.accelerator.unwrap_model(model)
1004-
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
1004+
forward_kwargs = self._prepare_model_inputs(inputs)
1005+
return self._forward_redirection(model, unwrapped_model,
1006+
lambda *_, **__: self.compute_liger_loss(unwrapped_model, inputs),
1007+
**forward_kwargs)
10051008
else:
10061009
return self._compute_loss(model, inputs)
10071010

0 commit comments

Comments
 (0)